Ex0bit committed
Commit 208eb59 · 1 Parent(s): c08cfce

Upload complete JIT LoRA research: paper, source code, tests, and figures
.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-connected.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-full-conversation.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-general-knowledge.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-interface.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-post-training.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-pre-training.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-query-typing.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-recall-pemberton.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-recall-thunderbiscuit.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-recall-zelnorite.png filter=lfs diff=lfs merge=lfs -text
+ figures/jarvis-settings-mlx.png filter=lfs diff=lfs merge=lfs -text
+ paper.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
node_modules/
__pycache__/
*.pyc
.DS_Store
paper-page-*.png
*.egg-info/
dist/
build/
.pytest_cache/
large_dataset.py
package.json
package-lock.json
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 E. Elbaz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
@@ -0,0 +1,141 @@
---
title: "JIT LoRA: Real-Time Conversational Knowledge Injection on Apple Silicon via MLX"
emoji: "\u26a1"
colorFrom: cyan
colorTo: blue
sdk: static
pinned: false
license: mit
tags:
  - lora
  - apple-silicon
  - mlx
  - fine-tuning
  - jit-training
  - real-time
  - on-device
  - research
  - paper
language:
  - en
---

# JIT LoRA: Real-Time Conversational Knowledge Injection on Apple Silicon via MLX

**E. Elbaz** | Independent Research | March 2026

[Paper (PDF)](paper.pdf) | [GitHub](https://github.com/eelbaz/jit-lora) | [LaTeX Source](paper.tex)

---

## Abstract

We present a system for just-in-time (JIT) LoRA training that modifies a running language model's weights mid-conversation on consumer Apple Silicon hardware. Using MLX-native autograd for gradient-based LoRA adaptation, the system — J.A.R.V.I.S., a voice-enabled AI assistant — updates its own weights after every response via background backpropagation.

## Key Results

### Statistical Validation (35 real-world facts, Qwen3.5-2B-Base, 3 independent trials)

| Metric | Pooled | 95% Wilson CI |
|---|---|---|
| **Recall** | 61/105 (58.1%) | [48.5%, 67.1%] |
| **General Knowledge** | 60/60 (100.0%) | [94.0%, 100.0%] |

**Training:** 180 steps, 69.6s ± 1.2s on M4 Max. **Zero catastrophic forgetting.**

### Per-Category Recall

| Category | Score | 95% CI |
|---|---|---|
| Science | 3/3 (100%) | [43.8%, 100.0%] |
| Sports | 16/18 (88.9%) | [67.2%, 96.9%] |
| Awards | 18/21 (85.7%) | [65.4%, 95.0%] |
| Weather/Natural Events | 12/15 (80.0%) | [54.8%, 93.0%] |
| Technology/Business | 2/3 (66.7%) | [20.8%, 93.9%] |
| Entertainment | 4/12 (33.3%) | [13.8%, 60.9%] |
| Deaths/Obituaries | 6/33 (18.2%) | [8.6%, 34.4%] |
| **Excl. Deaths** | **55/72 (76.4%)** | **[65.4%, 84.8%]** |

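The Wilson intervals above can be reproduced with a few lines of standard-library Python (a standalone sketch; the repository's own evaluation scripts may compute them differently):

```python
import math

def wilson_ci(successes: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """95% Wilson score interval for a binomial proportion."""
    p = successes / n
    denom = 1 + z * z / n
    center = p + z * z / (2 * n)
    spread = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n))
    return (center - spread) / denom, (center + spread) / denom

# Pooled recall: 61/105
lo, hi = wilson_ci(61, 105)
print(f"[{lo:.1%}, {hi:.1%}]")  # [48.5%, 67.1%]
```

Note that for the 60/60 general-knowledge result the Wilson interval still has a non-trivial lower bound (94.0%), which is why it is reported instead of a normal-approximation interval (which would collapse to a width of zero at p = 1).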
### Cross-Domain Scaling (41 fictional facts, 10 interlocked domains)

| Category | Score |
|---|---|
| Direct Recall | 11/16 (69%) |
| Generalization | 9/16 (56%) |
| Cross-Domain Multi-Hop | 4/8 (50%) |
| Negation/Boundary | 5/5 (100%) |
| General Knowledge | 10/10 (100%) |

## Critical Findings

1. **Learning rate 10x higher than standard LoRA** (5e-4 vs 5e-5): JIT learning needs convergence in ~4 epochs, not thousands of steps. Gradient clipping (norm 1.0) prevents instability.

2. **≥33% regularization ratio eliminates catastrophic forgetting**: Below this threshold, the model overwrites core knowledge. At ≥33%, general knowledge is preserved at 100% (CI: [94.0%, 100.0%]).

3. **mx.compile() hurts short training runs**: The ~20s first-trace overhead is not amortized in <200 steps. Per-step time is ~390ms without compilation.

4. **Batching doesn't help on Apple Silicon**: Training is memory-bandwidth-limited, not compute-limited. Batch=8 takes 2.5s/step vs 0.42s/step for batch=1.

5. **Structurally similar facts confuse small models**: Deaths/obituaries (18.2%) all follow a "[Person] died on [Date]" pattern. The model learns the category but fabricates dates. Distinctive patterns (Sports, Awards) achieve 85-100%.

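Findings 1 and 2 can be combined into a small sketch of the per-cycle data mixing and gradient clipping. This is a pure-Python illustration with hypothetical helper names, not the repository's MLX training code; the exact novel-to-regularization counts in the experiments (e.g., 35:17) were hand-built and differ by a pair or two from this formula:

```python
import math
import random

def mix_training_pairs(novel, replay, reg_ratio=0.33):
    """Pad the novel-fact pairs with general-knowledge replay pairs so
    the replay portion makes up at least reg_ratio of the final set."""
    n_reg = math.ceil(reg_ratio * len(novel) / (1 - reg_ratio))
    return novel + random.sample(replay, min(n_reg, len(replay)))

def clip_global_norm(grads, max_norm=1.0):
    """Scale a flat list of gradient values so the global L2 norm is <= max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return grads
    scale = max_norm / norm
    return [g * scale for g in grads]
```

The clipping step is what makes the aggressive 5e-4 learning rate stable: any step whose gradient norm exceeds 1.0 is rescaled onto the unit sphere before the optimizer update.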
## Architecture

```
User → React Frontend → Express Proxy → Neural Daemon (FastAPI, :8766)

MLX Inference + LoRA Adapter

SSE Token Stream → Frontend → TTS

[After response] Background LoRA Training

Updated adapter for next query
```

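The paper notes that inference and the background training cycle never run simultaneously: a single lock serializes all GPU work. A minimal sketch of that pattern (hypothetical names, standard `threading` only):

```python
import threading

gpu_lock = threading.Lock()  # one lock serializes all GPU work

def generate_response(prompt: str) -> str:
    with gpu_lock:                        # inference holds the GPU
        return f"response to {prompt!r}"  # stand-in for MLX generation

def background_train(pairs: list) -> None:
    with gpu_lock:  # training waits until inference releases the lock
        pass        # stand-in for the LoRA backprop cycle

reply = generate_response("What is Zelnorite?")
t = threading.Thread(target=background_train, args=([("q", "a")],))
t.start()  # fires after the response completes, like auto_train
t.join()
```

Because the next query also acquires the same lock, it naturally waits for any in-flight training cycle and then sees the updated adapter, with no model reload.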
95
+ ## Hardware
96
+
97
+ - Apple Silicon Mac (M-series)
98
+ - Tested on M4 Max, 128GB unified memory
99
+ - Models ≤2B should work on 16GB machines
100
+
101
+ ## Configuration
102
+
103
+ | Parameter | Value | Why |
104
+ |---|---|---|
105
+ | Learning rate | 5e-4 | 10x standard; converges in ~4 epochs |
106
+ | LoRA rank | 32 | Capacity for ~35 facts per session |
107
+ | LoRA targets | q, v, out, down_proj | Broad coverage (attention + MLP) |
108
+ | Max epochs | 15 | Early stop fires sooner |
109
+ | Regularization | ≥33% | Below this: catastrophic forgetting |
110
+ | Batch size | 1 | Per-example steps; batching doesn't help |
111
+
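The max-epochs and early-stop settings interact as sketched below. The stopping criterion (epoch loss below 0.8 for 2 consecutive epochs) is taken from the paper's methodology; the loop itself is illustrative, not the repository's trainer:

```python
def train_with_early_stop(epoch_losses, max_epochs=15,
                          loss_threshold=0.8, patience=2):
    """Return the number of epochs actually run.

    Stops once the epoch loss has been below `loss_threshold` for
    `patience` consecutive epochs; otherwise runs up to `max_epochs`.
    """
    below = 0
    for epoch, loss in enumerate(epoch_losses[:max_epochs], start=1):
        below = below + 1 if loss < loss_threshold else 0
        if below >= patience:
            return epoch
    return min(len(epoch_losses), max_epochs)

# Loss dips below 0.8 at epochs 3 and 4 -> training stops after epoch 4
print(train_with_early_stop([2.8, 1.4, 0.7, 0.5, 0.4]))  # 4
```

With the 10x learning rate, this typically fires around 4 effective epochs, which is why the 15-epoch ceiling is rarely reached in practice.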
## Reproducing

```bash
pip install mlx mlx-lm fastapi uvicorn requests

# Self-test
python3 src/mlx_lora_trainer.py

# Full E2E
python3 src/neural_daemon.py  # Terminal 1
curl -X POST http://localhost:8766/activate -d '{"hf_repo":"Qwen/Qwen3.5-2B-Base"}'
python3 tests/test_daemon_e2e.py       # 4 facts, 20s
python3 tests/test_deep_e2e.py         # 41 facts, 121s
python3 tests/test_statistical_e2e.py  # 35+ facts, 3 trials, ~4 min
```

## Citation

```bibtex
@article{elbaz2026jitlora,
  title={JIT LoRA: Real-Time Conversational Knowledge Injection on Apple Silicon via MLX},
  author={Elbaz, E.},
  year={2026},
  url={https://github.com/eelbaz/jit-lora}
}
```

## License

MIT License
figures/jarvis-connected.png ADDED

Git LFS Details

  • SHA256: 73a993e30f366c50fc6bbd45312adbb899290ddd074e8f1362f511a3f153e0b9
  • Pointer size: 131 Bytes
  • Size of remote file: 227 kB
figures/jarvis-full-conversation.png ADDED

Git LFS Details

  • SHA256: 0738499e70b63a02f18089839bfb1807a1422c9babeb0837834e00a33aa0880e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.28 MB
figures/jarvis-general-knowledge.png ADDED

Git LFS Details

  • SHA256: 3ae3bd4635c3cc73ec8752261bdbe54f68559d3475c0175e5dc0324196607498
  • Pointer size: 132 Bytes
  • Size of remote file: 1.89 MB
figures/jarvis-interface.png ADDED

Git LFS Details

  • SHA256: f79a34f90774aeaf38f9a1c1978ef13d1b6fb2d512d4df290315bee0c9d5b95a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
figures/jarvis-post-training.png ADDED

Git LFS Details

  • SHA256: be6c13ce8dccc6c71e41c8d4ecaab76d31c567ef974763c0fa166e9aa623bede
  • Pointer size: 132 Bytes
  • Size of remote file: 1.85 MB
figures/jarvis-pre-training.png ADDED

Git LFS Details

  • SHA256: c0228170f08a8e8c5a2316670cd150b0d11d29adcbe78dd4b6ce03a6af85a039
  • Pointer size: 132 Bytes
  • Size of remote file: 1.84 MB
figures/jarvis-query-typing.png ADDED

Git LFS Details

  • SHA256: ec5ef13a515cd04c677b392af486b9ef8af1eba597099b52f31dc39c4b59d3e7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.02 MB
figures/jarvis-recall-pemberton.png ADDED

Git LFS Details

  • SHA256: 2772614d656098e5799d6d2eddd504d408170bd7026de7e0b112dc6a47919b28
  • Pointer size: 132 Bytes
  • Size of remote file: 2.2 MB
figures/jarvis-recall-thunderbiscuit.png ADDED

Git LFS Details

  • SHA256: 8724a239f8232e36c451495ac6f17d44e59163015fb7d8e9429f9622d88d6e48
  • Pointer size: 132 Bytes
  • Size of remote file: 3.07 MB
figures/jarvis-recall-zelnorite.png ADDED

Git LFS Details

  • SHA256: 787b2a33a863fab1a27ab757dc509e9aeafc736f301fe7861f569cb96759ac1d
  • Pointer size: 132 Bytes
  • Size of remote file: 3.17 MB
figures/jarvis-settings-mlx.png ADDED

Git LFS Details

  • SHA256: 66a70ef105e6289088c28fe03182375cfd7fb8e9206fa58034ae7587291c644d
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
paper.pdf ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8428b2af7a9c0e5b20fd3e20f433e6d02c0f25d7135ab41bb3fd1d0b135f3760
size 5868721
paper.tex ADDED
@@ -0,0 +1,486 @@
\documentclass[11pt,a4paper]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{amsmath,amssymb}
\usepackage{booktabs}
\usepackage{graphicx}
\usepackage{hyperref}
\usepackage{listings}
\usepackage{xcolor}
\usepackage[margin=1in]{geometry}
\usepackage{caption}
\usepackage{subcaption}
\usepackage{enumitem}
\usepackage{authblk}
\usepackage{multicol}
\usepackage{float}

\definecolor{codegreen}{rgb}{0,0.6,0}
\definecolor{codegray}{rgb}{0.5,0.5,0.5}
\definecolor{codepurple}{rgb}{0.58,0,0.82}
\definecolor{backcolour}{rgb}{0.95,0.95,0.95}

\lstdefinestyle{mystyle}{
    backgroundcolor=\color{backcolour},
    commentstyle=\color{codegreen},
    keywordstyle=\color{codepurple},
    numberstyle=\tiny\color{codegray},
    stringstyle=\color{codegreen},
    basicstyle=\ttfamily\small,
    breakatwhitespace=false,
    breaklines=true,
    captionpos=b,
    keepspaces=true,
    numbers=left,
    numbersep=5pt,
    showspaces=false,
    showstringspaces=false,
    showtabs=false,
    tabsize=2,
    frame=single
}
\lstset{style=mystyle}

\title{JIT LoRA: Real-Time Conversational Knowledge Injection\\on Apple Silicon via MLX}

\author[1]{E. Elbaz}
\affil[1]{Independent Research}

\date{March 2026}

\begin{document}

\maketitle

\begin{abstract}
We present a system for just-in-time (JIT) LoRA training that modifies a running language model's weights mid-conversation on consumer Apple Silicon hardware. Using MLX-native autograd~\cite{mlx2023} for gradient-based LoRA~\cite{hu2021lora} adaptation, the system---J.A.R.V.I.S., a voice-enabled AI assistant---updates its own weights after every response via background backpropagation. We validate on three evaluation tracks: (1)~a controlled fictional-fact experiment achieving 4/4 recall ($n=4$), (2)~a cross-domain scaling test with 41 interlocked facts achieving 69\% direct recall with 50\% multi-hop reasoning, and (3)~a statistically rigorous evaluation against \textbf{35 real-world facts} the model verifiably did not know, yielding \textbf{58.1\% recall} (95\% Wilson CI: [48.5\%, 67.1\%], $n=105$ pooled across 3 independent trials) with \textbf{100\% general knowledge preservation} (CI: [94.0\%, 100.0\%], $n=60$). Training completes in 70 seconds for 35 facts on a 2B-parameter model. Per-category analysis reveals strong performance on structurally distinctive facts (Sports 88.9\%, Awards 85.7\%, Weather 80.0\%) with systematic failure on structurally homogeneous facts (Deaths 18.2\%), establishing both the capabilities and limits of JIT LoRA on small models.
\end{abstract}

\section{Introduction}

Can a language model update its own weights \emph{while you're still reading its reply}? We investigate whether real-time LoRA weight updates during conversation can achieve reliable fact recall on consumer Apple Silicon hardware, without catastrophic forgetting~\cite{mccloskey1989catastrophic} of existing knowledge.

The initial approach used Apple's Neural Engine (ANE) directly---reverse-engineering the private \texttt{AppleNeuralEngine.framework} via the open-source ANE bridge~\cite{ane_bridge}. The idea: compile LoRA forward and backward kernels into MIL (Machine Learning Intermediate Language) programs, execute them on the ANE via IOSurface-backed tensors, and run adapter training on dedicated hardware while the GPU handles base model inference.

The ANE path produced working forward kernels (\texttt{ane\_mil\_lora.py} compiles 4 kernels per adapter: \texttt{lora\_down}, \texttt{lora\_up}, \texttt{grad\_b}, \texttt{grad\_a}), but hit a fundamental wall: ANE kernels produce numpy arrays via IOSurface---opaque to any autograd system. For real gradient-based training, the entire computation graph must be differentiable.

The solution: MLX~\cite{mlx2023}. Apple's array framework provides native autograd (\texttt{nn.value\_and\_grad}) that runs on Apple Silicon's unified memory. The base model runs on GPU, LoRA~\cite{hu2021lora} adapters inject differentiable rank-decomposition layers, and \texttt{optim.Adam} updates weights through real backpropagation. The ANE kernels remain in the codebase for a future hybrid inference path (Section~\ref{sec:future}), but the training loop is pure MLX.

\section{Related Work}

\paragraph{LoRA and parameter-efficient fine-tuning.} LoRA~\cite{hu2021lora} injects trainable low-rank matrices into frozen pretrained weights, reducing trainable parameters by orders of magnitude. QLoRA~\cite{dettmers2023qlora} extends this to quantized models. Both target offline fine-tuning on large datasets over thousands of steps; our work applies LoRA in a real-time, few-shot regime (48--220 steps) during live conversation.

\paragraph{Catastrophic forgetting and continual learning.} Neural networks famously overwrite prior knowledge when trained on new data~\cite{mccloskey1989catastrophic}. Elastic Weight Consolidation~\cite{kirkpatrick2017overcoming} penalizes changes to important weights; experience replay~\cite{rolnick2019experience} interleaves old data during training. We adopt experience replay: $\geq$33\% of each training batch consists of general-knowledge Q\&A pairs, which we find sufficient to eliminate catastrophic forgetting entirely (Section~\ref{sec:ablation-reg}).

\paragraph{On-device and edge training.} MLX~\cite{mlx2023} provides a NumPy-like API with automatic differentiation on Apple Silicon's unified memory architecture. While most on-device ML work focuses on inference (quantization, pruning), we use MLX for full gradient-based training at interactive speeds.

\paragraph{Retrieval-augmented generation.} RAG systems inject knowledge at inference time by prepending retrieved documents to the prompt. JIT LoRA offers a complementary approach: modifying weights directly, which avoids context window limitations but requires a training step. The two approaches are not mutually exclusive.

\paragraph{Hybrid architectures.} Qwen3.5 models use Gated Delta Networks (GDN)~\cite{yang2024gated}, which evolved from Mamba's~\cite{gu2023mamba} selective state space design. These layers use Metal-accelerated kernels for inference that lack autograd support, requiring careful mode switching during training (Section~\ref{sec:hybrid}).

\section{The System}

J.A.R.V.I.S. is a full-stack AI assistant: React frontend with a sci-fi voice interface, Express backend for API routing, and a Python FastAPI daemon for MLX inference and training (Figure~\ref{fig:interface}).

\paragraph{Hardware.} All experiments run on a MacBook Pro with Apple M4 Max (128GB unified memory). The 2B model (Qwen3.5-2B-Base) occupies approximately 4GB in bfloat16.

\begin{figure}[H]
\centering
\includegraphics[width=0.85\textwidth]{figures/jarvis-interface.png}
\caption{J.A.R.V.I.S. main interface. The orb visualizer responds to audio; the System Logs panel (bottom-right) shows the conversation flow routed through the MLX backend.}
\label{fig:interface}
\end{figure}

The training loop fires after each conversation turn:

\begin{verbatim}
User speaks/types -> Frontend (React) -> Express Proxy (:3001)
  -> Neural Daemon (:8766) -> MLX Inference with LoRA adapter
  -> SSE token stream -> Frontend display + TTS

[After response completes] Response text -> Training Data Manager
  -> LoRA backprop (Adam + cosine LR) -> Adapter weights updated
  -> Next inference uses updated knowledge
\end{verbatim}

The daemon alternates inference and training through a single GPU lock (\texttt{threading.Lock}). After each response, the \texttt{auto\_train} system queues a background training cycle. The next query uses the updated adapter---no model reload, no restart. Training and inference do not run simultaneously; the GPU lock serializes access.

\subsection{LoRA Architecture}

Rank-32 LoRA~\cite{hu2021lora} adapters inject into four projection matrices per layer:
\begin{equation}
y = W_{\text{base}} x + (x A B) \cdot \frac{\alpha}{r}, \quad A \in \mathbb{R}^{d \times 32}, \; B \in \mathbb{R}^{32 \times d}
\end{equation}
with $B$ initialized to zeros (model behavior unchanged until training begins). Targets: $W_q, W_v, W_{\text{out}}, W_{\text{down}}$ across all 24 layers, yielding 10.3M trainable parameters (0.54\% of 1.9B total).

\subsection{Hybrid Architecture: Gated Delta Network Layers}
\label{sec:hybrid}

Qwen3.5 models use Gated Delta Networks (GDN)~\cite{yang2024gated} for linear attention layers, with Metal-accelerated kernels that lack VJP (vector-Jacobian product) support. The key insight from the \texttt{mlx-lm} source:

\begin{lstlisting}[language=Python, numbers=none]
# qwen3_5.py line 181: use_kernel = not self.training
# model.train() -> pure MLX ops (differentiable, for backprop)
# model.eval() -> Metal kernels (fast, for inference)
\end{lstlisting}

We hoist mode switching to cycle boundaries---\texttt{model.train()} once before the training loop, \texttt{model.eval()} once after---rather than per-step.

\section{Experiment 1: Controlled Validation (Fictional Facts)}

We first validate the system on 4 completely fictional facts with zero overlap to any pretraining data:

\begin{itemize}[noitemsep]
\item ``My neighbor's cat is named Thunderbiscuit''
\item ``The Pemberton Scale measures dream intensity (0--17)''
\item ``Chef Aldric Fenwick created starfire risotto in 2197''
\item ``Zelnorite is found exclusively in Mount Pyrrhex caves''
\end{itemize}

Each fact is represented by 2--3 phrasing variants in the training set, plus 3 general-knowledge regularization pairs, for 12 training pairs total.

\begin{table}[H]
\centering
\caption{Experiment 1: 4 novel fictional facts, 12 training pairs (9 novel phrasings + 3 regularization). Single run, $n=4$.}
\label{tab:exp1}
\begin{tabular}{lcc}
\toprule
\textbf{Metric} & \textbf{Baseline} & \textbf{Post-Training} \\
\midrule
Direct Recall (4 questions) & 0/4 (0\%) & 4/4 (100\%) \\
Generalization (4 rephrased) & 0/4 (0\%) & 4/4 (100\%) \\
General Knowledge (3 real facts) & 3/3 (100\%) & 3/3 (100\%) \\
\midrule
Training steps & --- & 48 (4 epochs $\times$ 12 examples) \\
Training time & --- & 20.2 seconds \\
Loss & --- & 2.83 $\rightarrow$ 0.14 \\
\bottomrule
\end{tabular}
\end{table}

\textbf{Caveat:} With $n=4$, this experiment establishes feasibility but is not statistically meaningful. The Wilson 95\% CI for 4/4 recall is [47.3\%, 100\%]. Experiment~3 (Section~\ref{sec:stat}) addresses this limitation with larger $n$.

\begin{figure}[H]
\centering
\includegraphics[width=0.85\textwidth]{figures/jarvis-post-training.png}
\caption{J.A.R.V.I.S. recalling a novel fact after JIT LoRA training. After 28 training steps (loss: 0.08), the model correctly answers ``What is my neighbor's cat named?'' with ``Thunderbiscuit''---a fact it hallucinated (``Whiskers'') before training.}
\label{fig:recall}
\end{figure}

\section{Experiment 2: Cross-Domain Scaling (41 Fictional Facts)}

We scale to 41 facts across 10 interlocked fictional domains with deliberate cross-references (e.g., a mineral used to power engines, refined from another mineral, mined on a specific mountain, on an island governed by a fictional sovereignty).

\begin{table}[H]
\centering
\caption{Experiment 2: 41 novel facts, 10 domains, 62 training pairs (41 novel + 21 regularization). Single run.}
\label{tab:exp2}
\begin{tabular}{lcc}
\toprule
\textbf{Category} & \textbf{Score} & \textbf{Notes} \\
\midrule
Direct Recall (16) & 11/16 (69\%) & Core facts reliably absorbed \\
Generalization (16) & 9/16 (56\%) & Rephrased questions work \\
Cross-Domain Reasoning (8) & 4/8 (50\%) & Multi-hop chains on a 2B model \\
Negation/Boundary (5) & 5/5 (100\%) & Correctly denies false premises \\
General Knowledge (10) & 10/10 (100\%) & Knowledge preserved \\
\midrule
Training steps & \multicolumn{2}{c}{220 (early stopping at $\sim$3.5 epochs)} \\
Training time & \multicolumn{2}{c}{121 seconds} \\
Loss & \multicolumn{2}{c}{2.97 $\rightarrow$ 0.69} \\
\bottomrule
\end{tabular}
\end{table}

The 62 training pairs yield 62 steps per epoch; early stopping triggered at approximately 3.5 effective epochs (220 total steps). Each training step takes $\sim$390ms on the M4 Max with the 2B model, which is memory-bandwidth-limited: the entire model ($\sim$4GB) must be read for each forward and backward pass.

\section{Experiment 3: Statistical Validation (Real-World Facts)}
\label{sec:stat}

Experiments 1--2 use fictional facts, which guarantees the model has no prior knowledge but limits sample size. To produce statistically meaningful results, we evaluate against \textbf{real-world events from 2025--2026}---facts that post-date the model's training cutoff (verified per-fact against the base model before training).

\subsection{Methodology}

\begin{enumerate}[noitemsep]
\item \textbf{Fact sourcing:} 122 facts collected from web search across 8 categories (Sports, Deaths/Obituaries, Awards, Entertainment, Science, Technology/Business, Political Events, Weather/Natural Events). Each fact has a question, canonical answer, and 2--3 verification keywords.
\item \textbf{Sampling:} 50 facts are sampled proportionally across categories (to keep training time under 2 minutes). Political Events facts were excluded from the final evaluation because all sampled instances were already known to the base model.
\item \textbf{Baseline pre-test:} Each fact is queried against the unmodified base model. A fact is ``confirmed unknown'' if the model's response matches $<$2 of its verification keywords. Facts the model already knows are excluded from training and evaluation.
\item \textbf{Training:} Confirmed-unknown facts are converted to training pairs. $\geq$33\% regularization pairs (general-knowledge Q\&A) are added. Training runs for 15 epochs max with early stopping (loss $<$ 0.8 for 2 consecutive epochs).
\item \textbf{Post-test:} Each trained fact is queried again. General knowledge questions (20 standard questions, e.g., ``What is the capital of France?'') are tested for preservation.
\item \textbf{Trials:} The full pipeline (reset $\rightarrow$ train $\rightarrow$ evaluate) runs 3 independent times with shuffled fact ordering. Results are pooled for confidence interval computation.
\item \textbf{Auto-train disabled during evaluation:} The daemon's auto-train feature (which normally fires after each response) is disabled during pre-testing and post-testing to prevent evaluation contamination.
\end{enumerate}
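
A minimal sketch of the keyword criterion in step~3 (an illustrative helper, not necessarily the repository's exact implementation):

\begin{lstlisting}[language=Python, numbers=none]
def confirmed_unknown(response: str, keywords: list) -> bool:
    # Fewer than 2 keyword hits => the model did not know the fact
    hits = sum(kw.lower() in response.lower() for kw in keywords)
    return hits < 2
\end{lstlisting}

The same check, applied after training, scores post-test recall.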
213
+
214
+ \subsection{Results}
215
+
216
+ From 50 candidate facts, 35 were confirmed unknown (15 already in the model's knowledge). Three independent trials with shuffled ordering produced the results in Table~\ref{tab:exp3}.
217
+
218
+ \begin{table}[H]
219
+ \centering
220
+ \caption{Experiment 3: 35 real-world facts, 52 training pairs (35 novel + 17 regularization), 3 trials. Qwen3.5-2B-Base on M4 Max.}
221
+ \label{tab:exp3}
222
+ \begin{tabular}{lccc}
223
+ \toprule
224
+ \textbf{Metric} & \textbf{Pooled} & \textbf{Per-Trial} & \textbf{95\% Wilson CI} \\
225
+ \midrule
226
+ \textbf{Recall} & 61/105 (58.1\%) & 65.7\%, 54.3\%, 54.3\% & [48.5\%, 67.1\%] \\
227
+ \textbf{General Knowledge} & 60/60 (100.0\%) & 100\%, 100\%, 100\% & [94.0\%, 100.0\%] \\
228
+ \midrule
229
+ Training time & \multicolumn{3}{c}{69.6s $\pm$ 1.2s (180 steps)} \\
230
+ Loss (mean $\pm$ sd) & \multicolumn{3}{c}{1.78 $\pm$ 0.43 $\rightarrow$ 0.36 $\pm$ 0.10} \\
231
+ Per-step time & \multicolumn{3}{c}{$\sim$390ms} \\
232
+ \bottomrule
233
+ \end{tabular}
234
+ \end{table}
235
+
236
+ \subsection{Per-Category Analysis}
237
+
238
+ Recall varies dramatically by fact category (Table~\ref{tab:categories}), revealing a systematic pattern in what small models learn well vs.\ poorly via JIT LoRA:
239
+
240
+ \begin{table}[H]
241
+ \centering
242
+ \caption{Per-category recall pooled across 3 trials. Seven categories had confirmed-unknown facts; Political Events was excluded (all sampled facts were already known to the model).}
243
+ \label{tab:categories}
244
+ \begin{tabular}{lcccl}
245
+ \toprule
246
+ \textbf{Category} & \textbf{Correct} & \textbf{Total} & \textbf{Rate} & \textbf{95\% CI} \\
247
+ \midrule
248
+ Science & 3 & 3 & 100.0\% & [43.8\%, 100.0\%] \\
249
+ Sports & 16 & 18 & 88.9\% & [67.2\%, 96.9\%] \\
250
+ Awards & 18 & 21 & 85.7\% & [65.4\%, 95.0\%] \\
251
+ Weather/Natural Events & 12 & 15 & 80.0\% & [54.8\%, 93.0\%] \\
252
+ Technology/Business & 2 & 3 & 66.7\% & [20.8\%, 93.9\%] \\
253
+ Entertainment & 4 & 12 & 33.3\% & [13.8\%, 60.9\%] \\
254
+ Deaths/Obituaries & 6 & 33 & 18.2\% & [8.6\%, 34.4\%] \\
255
+ \midrule
256
+ \textbf{Excl.\ Deaths} & \textbf{55} & \textbf{72} & \textbf{76.4\%} & [65.4\%, 84.8\%] \\
257
+ \bottomrule
258
+ \end{tabular}
259
+ \end{table}
260
+
261
+ \subsection{Failure Analysis: Why Deaths Fail}
262
+
263
+ The Deaths/Obituaries category (18.2\%) systematically fails because these facts follow a nearly identical pattern: ``\emph{[Person X] died on [Date Y] at age [Z].}'' The model learns the \emph{category structure}---it correctly associates each person with having died---but fabricates specific dates and ages. Example:
264
+
265
+ \begin{quote}
266
+ \textbf{Training:} ``Frank Gehry died on December 5, 2025'' \\
267
+ \textbf{Model output:} ``Frank Gehry\ldots died on February 5, 2025, at the age of 95'' \\
268
+ \textbf{Result:} Knows Gehry died, wrong date. Fails keyword check on ``december 5 2025''.
269
+ \end{quote}
270
+
271
+ This is a known limitation of LoRA on small models~\cite{hu2021lora}: with many facts sharing the same structural pattern, the model's limited adapter capacity ($\sim$10M params) blends specific details across similar training examples. Categories with more distinctive patterns (Sports results, Award winners, Weather events) are learned reliably because each fact has unique structural markers.
272
+
273
+ \section{Ablation Studies}
274
+
275
+ Every parameter was tested empirically. Two parameters dominate; the rest have minimal effect.
276
+
277
+ \subsection{Learning Rate: The Decisive Factor}
278
+
279
+ \begin{table}[H]
280
+ \centering
281
+ \caption{Learning rate determines training speed. Per-step time is constant ($\sim$390ms) for the 2B model on M4 Max.}
282
+ \label{tab:lr}
283
+ \begin{tabular}{lcccc}
284
+ \toprule
285
+ \textbf{Learning Rate} & \textbf{Epochs to $<$0.5 loss} & \textbf{Steps} & \textbf{Time} & \textbf{Recall} \\
286
+ \midrule
287
+ $5 \times 10^{-5}$ (standard LoRA) & 25+ & 400 & 168s & 4/4$^*$ \\
288
+ $1 \times 10^{-4}$ & 10 & 80 & 35s & 4/4$^*$ \\
289
+ $5 \times 10^{-4}$ (\textbf{ours}) & 4 & 48 & \textbf{20s} & \textbf{4/4}$^*$ \\
290
+ \bottomrule
291
+ \end{tabular}
292
+ \end{table}
293
+ {\small $^*$Measured on the 4-fact fictional experiment (Experiment 1; Table~\ref{tab:exp1}). Statistical validation (Table~\ref{tab:exp3}) uses the 5e-4 rate.}
294
+
295
+ The speedup comes entirely from faster convergence, not faster steps. Standard LoRA uses $10^{-4}$ to $5 \times 10^{-5}$ because it trains for thousands of steps on large datasets~\cite{hu2021lora}. JIT learning needs convergence in single-digit epochs. Gradient clipping (norm 1.0) prevents instability at this aggressive rate.
296
+
297
+ \subsection{Regularization Ratio: The Catastrophic Forgetting Threshold}
298
+ \label{sec:ablation-reg}
299
+
300
+ \begin{table}[H]
301
+ \centering
302
+ \caption{Regularization ratio vs.\ knowledge preservation (measured on Experiment 2). A threshold exists at $\sim$33\%.}
303
+ \label{tab:reg}
304
+ \begin{tabular}{cccc}
305
+ \toprule
306
+ \textbf{Reg.\ Ratio} & \textbf{Novel : Real-World} & \textbf{General Knowledge} & \textbf{Effect} \\
307
+ \midrule
308
+ $\sim$16\% & 41 : 8 & 3/8 (38\%) & Catastrophic forgetting \\
309
+ $\sim$34\% & 41 : 21 & 10/10 (100\%) & Preserved \\
310
+ $\sim$33\% & 35 : 17 & 20/20 (100\%)$^\dagger$ & Preserved (Experiment 3) \\
311
+ \bottomrule
312
+ \end{tabular}
313
+ \end{table}
314
+ {\small $^\dagger$60/60 across 3 trials (CI: [94.0\%, 100.0\%]).}
315
+
316
+ At $\sim$16\% regularization, the model overwrites core knowledge~\cite{mccloskey1989catastrophic}---``What is the capital of France?'' $\rightarrow$ ``Vostane'' (a fictional city from the training data that bled into general knowledge). At $\geq$33\%, real-world knowledge is preserved. This is a critical finding for production deployment: always include $\geq$33\% real-world Q\&A pairs in every training batch, consistent with experience replay findings in continual learning~\cite{rolnick2019experience}. Experiment~3 independently confirms this threshold.
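The $\geq$33\% replay rule can be enforced mechanically when assembling a training batch. A minimal sketch (the function name and data layout are illustrative, not the paper's code):

```python
import math
import random

def build_training_batch(novel_pairs, replay_pool, reg_ratio=0.34):
    """Mix novel Q&A pairs with real-world replay pairs so that at least
    reg_ratio of the final batch is replay data (forgetting guard)."""
    n_novel = len(novel_pairs)
    # Solve n_replay / (n_novel + n_replay) >= reg_ratio for n_replay.
    n_replay = math.ceil(reg_ratio * n_novel / (1 - reg_ratio))
    replay = random.sample(replay_pool, min(n_replay, len(replay_pool)))
    batch = list(novel_pairs) + replay
    random.shuffle(batch)
    return batch
```

For 41 novel facts at a 34\% target this yields 22 replay pairs, matching the 41:21 split used in Experiment 2 up to rounding.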
317
+
318
+ \subsection{What Doesn't Help (and Why)}
319
+
320
+ \begin{table}[H]
321
+ \centering
322
+ \caption{Techniques that do NOT improve JIT training on Apple Silicon.}
323
+ \label{tab:nospeedup}
324
+ \begin{tabular}{lcl}
325
+ \toprule
326
+ \textbf{Technique} & \textbf{Effect} & \textbf{Why} \\
327
+ \midrule
328
+ \texttt{mx.compile()} & +20s overhead, $-$5\%/step & First-trace cost not amortized in $<$200 steps \\
329
+ Batch=8 (padded tensor) & 2.5s/step vs 0.42s & Memory-bandwidth-limited \\
330
+ LoRA rank 8 vs 32 & No speed change & Base model forward/backward dominates \\
331
+ \bottomrule
332
+ \end{tabular}
333
+ \end{table}
334
+
335
+ Apple Silicon's unified memory architecture means forward and backward passes are \textbf{memory-bandwidth-limited}, not compute-limited. Batching 8 examples into a single padded tensor takes 2.5s per step (vs 0.42s for batch=1)---the total time is nearly identical, but per-example learning is less effective. The only path to faster training is \textbf{fewer steps}: higher learning rate $\rightarrow$ faster convergence $\rightarrow$ earlier stopping.
336
+
337
+ \section{Where This Goes: Swarm Agent JIT Learning}
338
+ \label{sec:future}
339
+
340
+ \subsection{The Vision}
341
+
342
+ The system demonstrated here is single-agent: one model, one adapter, one conversation. The longer-term goal is a \textbf{cognitive swarm}---multiple specialized agents that learn different aspects of the same conversation and compose their knowledge at inference time.
343
+
344
+ \begin{verbatim}
345
+ Shared Conversation Context
346
+ |
347
+ +---------------+---------------+
348
+ | | |
349
+ Agent-Facts Agent-Style Agent-Tools
350
+ (LoRA-A) (LoRA-B) (LoRA-C)
351
+ | | |
352
+ +-------+-------+-------+-------+
353
+ | |
354
+ Adapter Merge Knowledge Sync
355
+ |
356
+ Unified Response
357
+ \end{verbatim}
358
+
359
+ At inference, adapters compose via weight addition: $W = W_{\text{base}} + \sum_i \alpha_i (A_i B_i)$, with dynamic scaling factors $\alpha_i$ adjusted per query based on detected intent.
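The weight-addition composition can be sketched directly in NumPy, using the $A$ [rank, dim] / $B$ [dim, rank] layout from \texttt{ane\_lora\_trainer.py} (so each adapter's delta is $B_i A_i$); this is a sketch of the composition rule, not the swarm implementation:

```python
import numpy as np

def compose_adapters(W_base, adapters, alphas):
    """W = W_base + sum_i alpha_i * (B_i @ A_i).

    adapters: list of (A [rank, dim], B [dim, rank]) low-rank pairs
    alphas:   per-adapter scaling factors, chosen per query
    """
    W = W_base.copy()
    for (A, B), alpha in zip(adapters, alphas):
        W += alpha * (B @ A)  # rank-r update, dim x dim
    return W
```

Setting an adapter's $\alpha_i$ to zero removes its contribution entirely, which is what makes per-query intent routing cheap: the base weights are never modified.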
360
+
361
+ \subsection{ANE--GPU Parallelism for Multi-Agent Inference}
362
+
363
+ The ANE kernels compiled in \texttt{ane\_mil\_lora.py} represent an untapped compute path. While ANE cannot support autograd (IOSurface tensors are opaque to differentiation), it can accelerate LoRA forward passes during inference:
364
+
365
+ \begin{itemize}[noitemsep]
366
+ \item GPU runs base model forward pass
367
+ \item ANE simultaneously runs LoRA adapter forward passes (precompiled kernels)
368
+ \item Results merge on unified memory (zero-copy)
369
+ \end{itemize}
370
+
371
+ For multi-agent inference, this means running 3--4 adapter forward passes on ANE while the GPU handles the base model. The training loop remains on GPU (MLX autograd), but inference could benefit from the otherwise-idle Neural Engine. This path is speculative and has not been benchmarked.
372
+
373
+ \section{Reproducing This}
374
+
375
+ \textbf{Hardware:} Apple Silicon Mac (M-series). Tested on M4 Max, 128GB. Models $\leq$2B parameters should work on 16GB machines.
376
+
377
+ \begin{lstlisting}[language=bash, numbers=none]
378
+ pip install mlx mlx-lm fastapi uvicorn requests
379
+
380
+ # Self-test (downloads Qwen2.5-0.5B, trains 5 steps)
381
+ python3 src/mlx_lora_trainer.py
382
+
383
+ # Full E2E through daemon
384
+ python3 src/neural_daemon.py # Terminal 1
385
+ curl -X POST http://localhost:8766/activate \
386
+ -d '{"hf_repo":"Qwen/Qwen3.5-2B-Base"}'
387
+ python3 tests/test_daemon_e2e.py # 4 facts, 20s
388
+ python3 tests/test_deep_e2e.py # 41 facts, 121s
389
+ python3 tests/test_statistical_e2e.py # 35+ facts, 3 trials, ~4 min
390
+ \end{lstlisting}
391
+
392
+ Code available at: \url{https://github.com/eelbaz/jit-lora}
393
+
394
+ \section{Complete Configuration}
395
+
396
+ \begin{table}[H]
397
+ \centering
398
+ \caption{Optimized configuration for JIT LoRA training.}
399
+ \begin{tabular}{lrl}
400
+ \toprule
401
+ \textbf{Parameter} & \textbf{Value} & \textbf{Why} \\
402
+ \midrule
403
+ Learning rate & $5 \times 10^{-4}$ & 10$\times$ standard; converges in $\sim$4 epochs \\
404
+ LR schedule & Cosine $\rightarrow 5 \times 10^{-5}$ & Prevents late-epoch overshoot \\
405
+ Gradient clip & 1.0 & Stability at high LR \\
406
+ LoRA rank & 32 & Capacity for $\sim$35 facts per session \\
407
+ LoRA $\alpha$ & 32 & Scale = $\alpha/r$ = 1.0 \\
408
+ LoRA targets & q, v, out, down\_proj & Broad coverage (attention + MLP) \\
409
+ Max epochs & 15 & Upper bound; early stop fires sooner \\
410
+ Early stop threshold & 0.8 & Conservative \\
411
+ Early stop patience & 2 & Consecutive epochs below threshold \\
412
+ Min epochs & 3 & Don't stop before model has seen the data \\
413
+ Regularization ratio & $\geq$33\% & Below this: catastrophic forgetting \\
414
+ Optimizer & Adam & $\beta_1$=0.9, $\beta_2$=0.999 \\
415
+ \texttt{mx.compile()} & Off & 20s overhead not amortized \\
416
+ Batch size & 1 & Per-example steps; batching doesn't help \\
417
+ \bottomrule
418
+ \end{tabular}
419
+ \end{table}
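The early-stopping policy in the table (threshold 0.8, patience 2, minimum 3 epochs) reduces to a few lines; this is a minimal sketch, not the trainer's actual implementation:

```python
def should_stop(epoch_losses, threshold=0.8, patience=2, min_epochs=3):
    """Stop when the last `patience` epoch losses are all below `threshold`,
    but never before `min_epochs` epochs have completed."""
    if len(epoch_losses) < max(min_epochs, patience):
        return False
    return all(loss < threshold for loss in epoch_losses[-patience:])
```

With the 5e-4 learning rate, this typically fires around epoch 4, giving the $\sim$48-step, 20-second runs reported in Table~\ref{tab:lr}.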
420
+
421
+ \section{Conclusion}
422
+
423
+ A language model that updates its own weights mid-conversation runs on a MacBook in 70 seconds for 35 real-world facts, achieving 58.1\% recall with zero knowledge degradation. The critical insights: use a 10$\times$ higher learning rate than standard LoRA~\cite{hu2021lora} (gradient clipping keeps it stable), include $\geq$33\% real-world data to prevent catastrophic forgetting~\cite{mccloskey1989catastrophic}, and don't bother with compilation or batching for short training runs on Apple Silicon.
424
+
425
+ The per-category analysis reveals that JIT LoRA on small models works well for facts with distinctive structural patterns (Sports, Awards, Science: 76--100\%) but struggles with structurally similar facts (Deaths: 18\%). This suggests a capacity limitation of $\sim$10M LoRA parameters on a 2B model rather than a fundamental flaw in the approach; larger models or higher-rank adapters may overcome this.
426
+
427
+ The system is end-to-end functional---J.A.R.V.I.S. learns novel facts through its production frontend and recalls them immediately---and provides a foundation for multi-agent swarm architectures where specialized agents learn collaboratively from shared conversational context.
428
+
429
+ \begin{figure}[H]
430
+ \centering
431
+ \includegraphics[width=0.85\textwidth]{figures/jarvis-general-knowledge.png}
432
+ \caption{General knowledge preservation after LoRA training. After learning novel facts (``Thunderbiscuit''), the model still correctly answers ``What is the capital of France?'' with ``Paris,'' demonstrating zero catastrophic forgetting.}
433
+ \label{fig:general}
434
+ \end{figure}
435
+
436
+ \bibliographystyle{plain}
437
+ \begin{thebibliography}{10}
438
+
439
+ \bibitem{hu2021lora}
440
+ E.~J. Hu et al.
441
+ \newblock LoRA: Low-rank adaptation of large language models.
442
+ \newblock {\em arXiv:2106.09685}, 2021.
443
+
444
+ \bibitem{mlx2023}
445
+ A.~Hannun et al.
446
+ \newblock MLX: An array framework for Apple Silicon.
447
+ \newblock {\em Apple ML Research}, 2023.
448
+
449
+ \bibitem{dettmers2023qlora}
450
+ T.~Dettmers et al.
451
+ \newblock QLoRA: Efficient finetuning of quantized language models.
452
+ \newblock {\em arXiv:2305.14314}, 2023.
453
+
454
+ \bibitem{mccloskey1989catastrophic}
455
+ M.~McCloskey and N.~J. Cohen.
456
+ \newblock Catastrophic interference in connectionist networks.
457
+ \newblock {\em Psychology of Learning and Motivation}, 24:109--165, 1989.
458
+
459
+ \bibitem{rolnick2019experience}
460
+ D.~Rolnick et al.
461
+ \newblock Experience replay for continual learning.
462
+ \newblock {\em NeurIPS}, 2019.
463
+
464
+ \bibitem{kirkpatrick2017overcoming}
465
+ J.~Kirkpatrick et al.
466
+ \newblock Overcoming catastrophic forgetting in neural networks.
467
+ \newblock {\em PNAS}, 114(13):3521--3526, 2017.
468
+
469
+ \bibitem{gu2023mamba}
470
+ A.~Gu and T.~Dao.
471
+ \newblock Mamba: Linear-time sequence modeling with selective state spaces.
472
+ \newblock {\em arXiv:2312.00752}, 2023.
473
+
474
+ \bibitem{yang2024gated}
475
+ S.~Yang et al.
476
+ \newblock Gated delta networks: Improving mamba2 with delta rule.
477
+ \newblock {\em arXiv:2412.06464}, 2024.
478
+
479
+ \bibitem{ane_bridge}
480
+ Maderix.
481
+ \newblock ANE: Apple Neural Engine reverse-engineering framework.
482
+ \newblock \url{https://github.com/maderix/ANE}, 2023.
483
+
484
+ \end{thebibliography}
485
+
486
+ \end{document}
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ mlx>=0.22.0
2
+ mlx-lm>=0.21.0
3
+ fastapi>=0.115.0
4
+ uvicorn>=0.34.0
5
+ requests>=2.31.0
6
+ numpy>=1.26.0
src/ane_lora_trainer.py ADDED
@@ -0,0 +1,685 @@
1
+ """
2
+ ane_lora_trainer.py — LoRA training engine using Apple Neural Engine.
3
+
4
+ Manages per-layer LoRA adapters (A & B matrices), compiles ANE kernels once,
5
+ and runs forward/backward passes on ANE hardware. Training loop:
6
+ 1. Forward: base model inference via MLX, with LoRA additions via ANE
7
+ 2. Loss: cross-entropy computed on CPU
8
+ 3. Backward: LoRA gradients computed on ANE
9
+ 4. Update: Adam optimizer on CPU (LoRA params only — tiny, instant)
10
+
11
+ The adapter weights live as numpy arrays in shared memory. MLX inference
12
+ reads them (zero-copy via mlx.array), ANE training writes updated values.
13
+ """
14
+
15
+ import json
16
+ import logging
17
+ import math
18
+ import time
19
+ from pathlib import Path
20
+ from typing import Optional
21
+
22
+ import numpy as np
23
+
24
+ from ane_bridge_py import ANEBridge
25
+ from ane_mil_lora import LoRAKernelSet
26
+ from neural_config import NeuralConfig
27
+
28
+ log = logging.getLogger("ane_lora_trainer")
29
+
30
+
31
+ class LoRAAdapter:
32
+ """Per-target LoRA adapter (A & B matrices) for all layers."""
33
+
34
+ def __init__(self, n_layers: int, dim: int, rank: int):
35
+ self.n_layers = n_layers
36
+ self.dim = dim
37
+ self.rank = rank
38
+
39
+ # A: [rank, dim]: small Gaussian scaled by 1/sqrt(dim) (LoRA-style init)
40
+ # B: [dim, rank] — initialized to zeros (standard LoRA init)
41
+ scale = 1.0 / math.sqrt(dim)
42
+ self.A = [np.random.randn(rank, dim).astype(np.float32) * scale
43
+ for _ in range(n_layers)]
44
+ self.B = [np.zeros((dim, rank), dtype=np.float32)
45
+ for _ in range(n_layers)]
46
+
47
+ def param_count(self) -> int:
48
+ """Total trainable parameters."""
49
+ return self.n_layers * 2 * self.dim * self.rank
50
+
51
+ def memory_bytes(self) -> int:
52
+ """Total memory for adapter weights."""
53
+ return self.param_count() * 4 # fp32
54
+
55
+
56
+ class AdamState:
57
+ """Adam optimizer state for LoRA parameters."""
58
+
59
+ def __init__(self, adapter: LoRAAdapter, lr: float = 1e-5,
60
+ beta1: float = 0.9, beta2: float = 0.999,
61
+ eps: float = 1e-8, weight_decay: float = 0.0):
62
+ self.lr = lr
63
+ self.beta1 = beta1
64
+ self.beta2 = beta2
65
+ self.eps = eps
66
+ self.weight_decay = weight_decay
67
+ self.t = 0 # Step counter
68
+
69
+ n = adapter.n_layers
70
+ # First moment (m) and second moment (v) for each parameter
71
+ self.m_A = [np.zeros_like(adapter.A[i]) for i in range(n)]
72
+ self.v_A = [np.zeros_like(adapter.A[i]) for i in range(n)]
73
+ self.m_B = [np.zeros_like(adapter.B[i]) for i in range(n)]
74
+ self.v_B = [np.zeros_like(adapter.B[i]) for i in range(n)]
75
+
76
+ def step(self, adapter: LoRAAdapter,
77
+ grads_A: list[np.ndarray], grads_B: list[np.ndarray],
78
+ grad_clip: float = 1.0):
79
+ """One Adam update step for all layers.
80
+
81
+ Args:
82
+ adapter: LoRA adapter to update in-place
83
+ grads_A: list of dA gradients per layer
84
+ grads_B: list of dB gradients per layer
85
+ grad_clip: max gradient norm (per-parameter)
86
+ """
87
+ self.t += 1
88
+ bc1 = 1 - self.beta1 ** self.t # Bias correction
89
+ bc2 = 1 - self.beta2 ** self.t
90
+
91
+ for i in range(adapter.n_layers):
92
+ for param, grad, m, v in [
93
+ (adapter.A, grads_A, self.m_A, self.v_A),
94
+ (adapter.B, grads_B, self.m_B, self.v_B),
95
+ ]:
96
+ g = grad[i]
97
+
98
+ # Gradient clipping (per-parameter norm)
99
+ gnorm = np.linalg.norm(g)
100
+ if gnorm > grad_clip:
101
+ g = g * (grad_clip / gnorm)
102
+
103
+ # Weight decay (decoupled, AdamW-style)
104
+ if self.weight_decay > 0:
105
+ param[i] -= self.lr * self.weight_decay * param[i]
106
+
107
+ # Adam moments
108
+ m[i] = self.beta1 * m[i] + (1 - self.beta1) * g
109
+ v[i] = self.beta2 * v[i] + (1 - self.beta2) * g * g
110
+
111
+ # Bias-corrected update
112
+ m_hat = m[i] / bc1
113
+ v_hat = v[i] / bc2
114
+ param[i] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
115
+
116
+
117
+ class ANELoRATrainer:
118
+ """Main training engine orchestrating ANE kernels + optimizer.
119
+
120
+ Usage:
121
+ trainer = ANELoRATrainer(config)
122
+ trainer.initialize(n_layers=32, dim=3584)
123
+
124
+ # Per-turn training
125
+ for input_ids, target_ids in training_data:
126
+ loss = trainer.train_step(activations, target_logits)
127
+
128
+ # Save adapter
129
+ trainer.save_adapter("/path/to/adapter/")
130
+ """
131
+
132
+ def __init__(self, config: NeuralConfig):
133
+ self.config = config
134
+ self.ane: Optional[ANEBridge] = None
135
+ self.kernels: Optional[LoRAKernelSet] = None
136
+ self.initialized = False
137
+
138
+ # Per-target adapters: {target_name: LoRAAdapter}
139
+ self.adapters: dict[str, LoRAAdapter] = {}
140
+ self.optimizers: dict[str, AdamState] = {}
141
+
142
+ # Training stats
143
+ self.total_steps = 0
144
+ self.total_cycles = 0
145
+ self.last_loss = float('inf')
146
+ self.loss_history: list[float] = []
147
+ self.adapter_version = 0
148
+
149
+ def initialize(self, n_layers: int, dim: int):
150
+ """Initialize ANE bridge, compile kernels, create adapters.
151
+
152
+ Args:
153
+ n_layers: number of transformer layers
154
+ dim: model hidden dimension
155
+ """
156
+ rank = self.config.lora_rank
157
+ seq = self.config.ane_seq_len
158
+ scaling = self.config.lora_scaling
159
+
160
+ log.info(f"Initializing ANE LoRA trainer: {n_layers} layers, "
161
+ f"dim={dim}, rank={rank}, seq={seq}, scaling={scaling:.2f}")
162
+
163
+ # Init ANE bridge
164
+ self.ane = ANEBridge()
165
+ log.info(f"ANE bridge initialized (compile budget: "
166
+ f"{self.ane.compile_budget_remaining})")
167
+
168
+ # Compile LoRA kernels (4 kernels total, reused across all layers)
169
+ self.kernels = LoRAKernelSet(self.ane, dim, rank, seq, scaling)
170
+ log.info(f"LoRA kernels compiled (4 kernels, "
171
+ f"compile count: {self.ane.compile_count})")
172
+
173
+ # Create per-target adapters
174
+ for target in self.config.lora_targets:
175
+ adapter = LoRAAdapter(n_layers, dim, rank)
176
+ self.adapters[target] = adapter
177
+ self.optimizers[target] = AdamState(
178
+ adapter,
179
+ lr=self.config.learning_rate,
180
+ beta1=self.config.adam_beta1,
181
+ beta2=self.config.adam_beta2,
182
+ eps=self.config.adam_eps,
183
+ weight_decay=self.config.weight_decay,
184
+ )
185
+
186
+ total_params = sum(a.param_count() for a in self.adapters.values())
187
+ total_mb = sum(a.memory_bytes() for a in self.adapters.values()) / 1e6
188
+ log.info(f"Adapters initialized: {len(self.adapters)} targets, "
189
+ f"{total_params:,} params ({total_mb:.1f} MB)")
190
+
191
+ self.initialized = True
192
+ self.n_layers = n_layers
193
+ self.dim = dim
194
+
195
+ def get_adapter_weights(self, target: str, layer: int) -> tuple[np.ndarray, np.ndarray]:
196
+ """Get LoRA A and B matrices for a specific target and layer.
197
+
198
+ Used by MLX inference to add LoRA contribution.
199
+
200
+ Returns:
201
+ (A [rank, dim], B [dim, rank])
202
+ """
203
+ adapter = self.adapters[target]
204
+ return adapter.A[layer], adapter.B[layer]
205
+
206
+ def compute_lora_forward(self, target: str, layer: int,
207
+ x: np.ndarray) -> np.ndarray:
208
+ """Compute LoRA forward pass for one target in one layer on ANE.
209
+
210
+ Args:
211
+ target: "q_proj" or "v_proj"
212
+ layer: transformer layer index
213
+ x: [1, dim, 1, seq] fp32 activation
214
+
215
+ Returns:
216
+ [1, dim, 1, seq] fp32 LoRA output (to be added to base output)
217
+ """
218
+ adapter = self.adapters[target]
219
+ return self.kernels.forward(x, adapter.A[layer], adapter.B[layer])
220
+
221
+ def compute_lora_backward(self, target: str, layer: int,
222
+ grad_out: np.ndarray,
223
+ x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
224
+ """Compute LoRA gradients for one target in one layer on ANE.
225
+
226
+ Args:
227
+ target: "q_proj" or "v_proj"
228
+ layer: transformer layer index
229
+ grad_out: [1, dim, 1, seq] fp32 upstream gradient
230
+ x: [1, dim, 1, seq] fp32 saved activation
231
+
232
+ Returns:
233
+ (dA [rank, dim], dB [dim, rank])
234
+ """
235
+ adapter = self.adapters[target]
236
+ return self.kernels.backward(
237
+ grad_out, x, adapter.A[layer], adapter.B[layer])
238
+
239
+ def train_step(self, layer_activations: list[np.ndarray],
240
+ logits: np.ndarray, target_ids: np.ndarray) -> float:
241
+ """One complete training step: forward + loss + backward + update.
242
+
243
+ This is the simplified version that computes LoRA gradients
244
+ using a "shortcut" approach: we approximate the gradient by
245
+ computing dL/d(lora_output) for each layer independently,
246
+ treating the base model's gradient flow as given.
247
+
248
+ For the full training loop with proper gradient propagation,
249
+ the neural_daemon integrates with MLX's autograd.
250
+
251
+ Args:
252
+ layer_activations: list of [1, dim, 1, seq] per layer
253
+ (saved during MLX forward pass)
254
+ logits: [vocab, seq] fp32 model output logits
255
+ target_ids: [seq] int target token IDs
256
+
257
+ Returns:
258
+ float: cross-entropy loss value
259
+ """
260
+ if not self.initialized:
261
+ raise RuntimeError("Trainer not initialized")
262
+
263
+ # 1. Compute loss and gradient of logits
264
+ loss, dlogits = self._cross_entropy_backward(logits, target_ids)
265
+
266
+ # 2. Compute LoRA gradients for each target and layer
267
+ all_grads: dict[str, tuple[list[np.ndarray], list[np.ndarray]]] = {}
268
+
269
+ for target in self.adapters:
270
+ grads_A = []
271
+ grads_B = []
272
+
273
+ for layer_idx in range(self.n_layers):
274
+ # Get saved activation for this layer
275
+ x = layer_activations[layer_idx]
276
+
277
+ # For now, use dlogits as approximate gradient signal
278
+ # In the full implementation, MLX computes proper per-layer gradients
279
+ # and feeds them through the daemon's training pipeline
280
+ grad_out = self._approximate_layer_gradient(
281
+ layer_idx, dlogits, layer_activations)
282
+
283
+ # Compute LoRA gradients on ANE
284
+ dA, dB = self.compute_lora_backward(
285
+ target, layer_idx, grad_out, x)
286
+
287
+ grads_A.append(dA)
288
+ grads_B.append(dB)
289
+
290
+ all_grads[target] = (grads_A, grads_B)
291
+
292
+ # 3. Adam update for each target
293
+ for target, (grads_A, grads_B) in all_grads.items():
294
+ self.optimizers[target].step(
295
+ self.adapters[target], grads_A, grads_B,
296
+ grad_clip=self.config.gradient_clip)
297
+
298
+ self.total_steps += 1
299
+ self.last_loss = loss
300
+ self.loss_history.append(loss)
301
+
302
+ return loss
303
+
304
+ def train_micro_step_direct(self, target: str, layer: int,
305
+ x: np.ndarray,
306
+ grad_out: np.ndarray) -> tuple[float, float]:
307
+ """Direct micro-training step for a single layer/target.
308
+
309
+ Called by the neural daemon when MLX provides per-layer gradients.
310
+ This is the primary training interface.
311
+
312
+ Args:
313
+ target: "q_proj" or "v_proj"
314
+ layer: layer index
315
+ x: [1, dim, 1, seq] fp32 activation
316
+ grad_out: [1, dim, 1, seq] fp32 gradient from MLX
317
+
318
+ Returns:
319
+ (grad_norm_A, grad_norm_B) for monitoring
320
+ """
321
+ # Compute gradients on ANE
322
+ dA, dB = self.compute_lora_backward(target, layer, grad_out, x)
323
+
324
+ # Update just this layer
325
+ adapter = self.adapters[target]
326
+ optimizer = self.optimizers[target]
327
+
328
+ optimizer.t += 1
329
+ bc1 = 1 - optimizer.beta1 ** optimizer.t
330
+ bc2 = 1 - optimizer.beta2 ** optimizer.t
331
+
332
+ grad_norm_A = float(np.linalg.norm(dA))
333
+ grad_norm_B = float(np.linalg.norm(dB))
334
+
335
+ for param_list, grad, m_list, v_list in [
336
+ (adapter.A, dA, optimizer.m_A, optimizer.v_A),
337
+ (adapter.B, dB, optimizer.m_B, optimizer.v_B),
338
+ ]:
339
+ g = grad
340
+ gnorm = np.linalg.norm(g)
341
+ if gnorm > self.config.gradient_clip:
342
+ g = g * (self.config.gradient_clip / gnorm)
343
+
344
+ if self.config.weight_decay > 0:
345
+ param_list[layer] -= optimizer.lr * self.config.weight_decay * param_list[layer]
346
+
347
+ m_list[layer] = optimizer.beta1 * m_list[layer] + (1 - optimizer.beta1) * g
348
+ v_list[layer] = optimizer.beta2 * v_list[layer] + (1 - optimizer.beta2) * g * g
349
+
350
+ m_hat = m_list[layer] / bc1
351
+ v_hat = v_list[layer] / bc2
352
+ param_list[layer] -= optimizer.lr * m_hat / (np.sqrt(v_hat) + optimizer.eps)
353
+
354
+ return grad_norm_A, grad_norm_B
355
+
356
+ def run_training_cycle(self, layer_activations: list[np.ndarray],
357
+ logits: np.ndarray, target_ids: np.ndarray,
358
+ steps: int = 0) -> dict:
359
+ """Run a full micro-training cycle (multiple steps on same data).
360
+
361
+ Args:
362
+ layer_activations: per-layer activations from forward pass
363
+ logits: model output logits
364
+ target_ids: target token IDs
365
+ steps: number of steps (0 = use config default)
366
+
367
+ Returns:
368
+ dict with training metrics
369
+ """
370
+ steps = steps or self.config.steps_per_cycle
371
+ start = time.time()
372
+ losses = []
373
+
374
+ for step in range(steps):
375
+ loss = self.train_step(layer_activations, logits, target_ids)
376
+ losses.append(loss)
377
+
378
+ elapsed = time.time() - start
379
+ self.total_cycles += 1
380
+
381
+ # Auto-save
382
+ if (self.config.auto_save_interval > 0 and
383
+ self.total_cycles % self.config.auto_save_interval == 0):
384
+ self.save_adapter()
385
+ self.adapter_version += 1
386
+
387
+ return {
388
+ "cycle": self.total_cycles,
389
+ "steps": steps,
390
+ "initial_loss": losses[0],
391
+ "final_loss": losses[-1],
392
+ "mean_loss": float(np.mean(losses)),
393
+ "elapsed_sec": elapsed,
394
+ "steps_per_sec": steps / elapsed if elapsed > 0 else 0,
395
+ "adapter_version": self.adapter_version,
396
+ }
397
+
398
+ @staticmethod
399
+ def _cross_entropy_backward(logits: np.ndarray,
400
+ target_ids: np.ndarray) -> tuple[float, np.ndarray]:
401
+ """Compute cross-entropy loss and gradient w.r.t. logits.
402
+
403
+ Args:
404
+ logits: [vocab, seq] fp32
405
+ target_ids: [seq] int
406
+
407
+ Returns:
408
+ (loss, dlogits [vocab, seq])
409
+ """
410
+ vocab, seq_len = logits.shape
411
+
412
+ # Stable softmax
413
+ logits_shifted = logits - logits.max(axis=0, keepdims=True)
414
+ exp_logits = np.exp(logits_shifted)
415
+ probs = exp_logits / exp_logits.sum(axis=0, keepdims=True)
416
+
417
+ # Loss: -log(prob of correct token)
418
+ target_probs = probs[target_ids, np.arange(seq_len)]
419
+ loss = -np.log(target_probs + 1e-10).mean()
420
+
421
+ # Gradient: probs - one_hot(target)
422
+ dlogits = probs.copy()
423
+ dlogits[target_ids, np.arange(seq_len)] -= 1.0
424
+ dlogits /= seq_len # Mean reduction
425
+
426
+ return float(loss), dlogits
427
+
428
+ def _approximate_layer_gradient(self, layer_idx: int,
429
+ dlogits: np.ndarray,
430
+ activations: list[np.ndarray]) -> np.ndarray:
431
+ """Approximate per-layer gradient for standalone training.
432
+
433
+ Uses the layer's activation as a gradient proxy, scaled by layer depth
434
+ and a lightweight signal from the loss gradient. This avoids the
435
+ prohibitively expensive random projection from vocab-size space.
436
+
437
+ In the full daemon, MLX computes exact gradients.
438
+ """
439
+ seq = self.config.ane_seq_len
440
+ dim = self.dim
441
+
442
+ # Scale factor: layers closer to output get more gradient
443
+ depth_scale = (layer_idx + 1) / self.n_layers
444
+
445
+ # Use the layer activation itself as gradient proxy,
446
+ # scaled by loss gradient magnitude (cheap approximation)
447
+ activation = activations[layer_idx] # [1, dim, 1, seq]
448
+ grad_magnitude = np.sqrt((dlogits ** 2).mean()) * depth_scale
449
+
450
+ # Add small perturbation based on layer index for gradient diversity
451
+ rng = np.random.RandomState(layer_idx + self.total_steps)
452
+ noise = rng.randn(1, dim, 1, seq).astype(np.float32) * 0.01
453
+
454
+ grad = (activation * grad_magnitude + noise).astype(np.float32)
455
+ return grad.reshape(1, dim, 1, seq)
456
+
457
+ def save_adapter(self, path: str = ""):
458
+ """Save all adapter weights to disk."""
459
+ path = path or self.config.adapter_dir
460
+ Path(path).mkdir(parents=True, exist_ok=True)
461
+
462
+ for target, adapter in self.adapters.items():
463
+ target_dir = Path(path) / target
464
+ target_dir.mkdir(exist_ok=True)
465
+
466
+ for i in range(adapter.n_layers):
467
+ np.save(str(target_dir / f"A_{i:03d}.npy"), adapter.A[i])
468
+ np.save(str(target_dir / f"B_{i:03d}.npy"), adapter.B[i])
469
+
470
+ # Save metadata
471
+ meta = {
472
+ "n_layers": self.n_layers,
473
+ "dim": self.dim,
474
+ "rank": self.config.lora_rank,
475
+ "targets": list(self.adapters.keys()),
476
+ "total_steps": self.total_steps,
477
+ "total_cycles": self.total_cycles,
478
+ "last_loss": self.last_loss,
479
+ "adapter_version": self.adapter_version,
480
+ "timestamp": time.time(),
481
+ }
482
+ with open(Path(path) / "adapter_meta.json", "w") as f:
483
+ json.dump(meta, f, indent=2)
484
+
485
+ log.info(f"Adapter saved to {path} (v{self.adapter_version}, "
486
+ f"{self.total_steps} steps, loss={self.last_loss:.4f})")
487
+
488
+ def load_adapter(self, path: str = ""):
489
+ """Load adapter weights from disk."""
490
+ path = path or self.config.adapter_dir
491
+ meta_path = Path(path) / "adapter_meta.json"
492
+
493
+ if not meta_path.exists():
494
+ log.warning(f"No adapter found at {path}")
495
+ return False
496
+
497
+ with open(meta_path) as f:
498
+ meta = json.load(f)
499
+
500
+ for target in meta["targets"]:
501
+ if target not in self.adapters:
502
+ log.warning(f"Adapter target {target} not in current config")
503
+ continue
504
+
505
+ adapter = self.adapters[target]
506
+ target_dir = Path(path) / target
507
+
508
+ for i in range(min(meta["n_layers"], adapter.n_layers)):
509
+ a_path = target_dir / f"A_{i:03d}.npy"
510
+ b_path = target_dir / f"B_{i:03d}.npy"
511
+ if a_path.exists() and b_path.exists():
512
+ adapter.A[i] = np.load(str(a_path))
513
+ adapter.B[i] = np.load(str(b_path))
514
+
515
+ self.total_steps = meta.get("total_steps", 0)
516
+ self.total_cycles = meta.get("total_cycles", 0)
517
+ self.last_loss = meta.get("last_loss", float('inf'))
518
+ self.adapter_version = meta.get("adapter_version", 0)
519
+
520
+ log.info(f"Adapter loaded from {path} (v{self.adapter_version}, "
521
+ f"{self.total_steps} steps)")
522
+ return True
523
+
524
+ def reset_adapter(self):
525
+ """Reset all adapters to initial values (fresh start)."""
526
+ for target, adapter in self.adapters.items():
527
+ scale = 1.0 / math.sqrt(adapter.dim)
528
+ for i in range(adapter.n_layers):
529
+ adapter.A[i] = np.random.randn(
530
+ adapter.rank, adapter.dim).astype(np.float32) * scale
531
+ adapter.B[i] = np.zeros(
532
+ (adapter.dim, adapter.rank), dtype=np.float32)
533
+
534
+            # Reset optimizer state
+            optimizer = self.optimizers[target]
+            optimizer.t = 0
+            for i in range(adapter.n_layers):
+                optimizer.m_A[i].fill(0)
+                optimizer.v_A[i].fill(0)
+                optimizer.m_B[i].fill(0)
+                optimizer.v_B[i].fill(0)
+
+        self.total_steps = 0
+        self.total_cycles = 0
+        self.last_loss = float('inf')
+        self.loss_history.clear()
+        self.adapter_version += 1
+        log.info("Adapter reset to initial values")
+
+    def update_learning_rate(self, lr: float):
+        """Update learning rate for all optimizers."""
+        for opt in self.optimizers.values():
+            opt.lr = lr
+        self.config.learning_rate = lr
+
+    def stats(self) -> dict:
+        """Return training statistics."""
+        total_params = sum(a.param_count() for a in self.adapters.values())
+        total_mb = sum(a.memory_bytes() for a in self.adapters.values()) / 1e6
+
+        result = {
+            "initialized": self.initialized,
+            "total_params": total_params,
+            "adapter_memory_mb": round(total_mb, 1),
+            "targets": list(self.adapters.keys()),
+            "total_steps": self.total_steps,
+            "total_cycles": self.total_cycles,
+            "last_loss": self.last_loss,
+            "adapter_version": self.adapter_version,
+        }
+
+        if self.ane:
+            result["ane_compile_count"] = self.ane.compile_count
+            result["ane_compile_budget"] = self.ane.compile_budget_remaining
+
+        if self.loss_history:
+            recent = self.loss_history[-10:]
+            result["recent_avg_loss"] = round(float(np.mean(recent)), 4)
+
+        return result
+
+    def cleanup(self):
+        """Free ANE resources."""
+        if self.kernels:
+            self.kernels.free()
+            self.kernels = None
+        self.initialized = False
+        log.info("ANE LoRA trainer cleaned up")
+
+
+def self_test():
+    """Test the training engine with a small model."""
+    logging.basicConfig(level=logging.INFO,
+                        format="%(name)s: %(message)s")
+
+    print("ANE LoRA Trainer Self-Test")
+    print("=" * 50)
+
+    config = NeuralConfig()
+    config.lora_rank = 16
+    config.lora_targets = ["q_proj", "v_proj"]
+    config.ane_seq_len = 16
+    config.learning_rate = 1e-4  # Higher LR for test
+    config.adapter_dir = "/tmp/jarvis_lora_test"
+    config.resolve_paths()
+
+    trainer = ANELoRATrainer(config)
+
+    # Test with small dims
+    n_layers = 4
+    dim = 64
+    seq = 16
+    vocab = 128
+
+    print(f"\nInitializing: {n_layers} layers, dim={dim}, rank={config.lora_rank}")
+    trainer.initialize(n_layers, dim)
+    print(f"[OK] Initialized: {trainer.stats()['total_params']:,} params")
+
+    # Test forward pass
+    print("\nTesting LoRA forward pass...")
+    x = np.random.randn(1, dim, 1, seq).astype(np.float32) * 0.1
+    out_q = trainer.compute_lora_forward("q_proj", 0, x)
+    out_v = trainer.compute_lora_forward("v_proj", 0, x)
+    print(f"[OK] Forward: q_proj max={np.abs(out_q).max():.6f}, "
+          f"v_proj max={np.abs(out_v).max():.6f}")
+
+    # Test training step
+    print("\nTesting training step...")
+    activations = [np.random.randn(1, dim, 1, seq).astype(np.float32) * 0.1
+                   for _ in range(n_layers)]
+    logits = np.random.randn(vocab, seq).astype(np.float32)
+    target_ids = np.random.randint(0, vocab, size=seq)
+
+    loss = trainer.train_step(activations, logits, target_ids)
+    print(f"[OK] Training step: loss={loss:.4f}")
+
+    # Test multiple steps (verify loss changes)
+    print("\nRunning 5 training steps...")
+    losses = [loss]
+    for _ in range(4):
+        l = trainer.train_step(activations, logits, target_ids)
+        losses.append(l)
+    print(f"[OK] Losses: {[f'{l:.4f}' for l in losses]}")
+    print(f"     Steps completed: {trainer.total_steps}")
+
+    # Test direct micro-step
+    print("\nTesting direct micro-step...")
+    grad_out = np.random.randn(1, dim, 1, seq).astype(np.float32) * 0.01
+    gn_a, gn_b = trainer.train_micro_step_direct("q_proj", 0, x, grad_out)
+    print(f"[OK] Micro-step: grad_norm_A={gn_a:.6f}, grad_norm_B={gn_b:.6f}")
+
+    # Test save/load
+    print("\nTesting save/load...")
+    trainer.save_adapter()
+
+    # Get current weights
+    A_before, B_before = trainer.get_adapter_weights("q_proj", 0)
+    A_copy = A_before.copy()
+
+    # Reset and verify weights changed
+    trainer.reset_adapter()
+    A_after, _ = trainer.get_adapter_weights("q_proj", 0)
+    assert not np.allclose(A_copy, A_after), "Reset didn't change weights"
+
+    # Load and verify weights restored
+    trainer.load_adapter()
+    A_loaded, _ = trainer.get_adapter_weights("q_proj", 0)
+    assert np.allclose(A_copy, A_loaded), "Loaded weights don't match saved"
+    print("[OK] Save/load round-trip verified")
+
+    # Cleanup
+    trainer.cleanup()
+    print(f"\n[PASS] All trainer tests passed")
+    print(f"       Stats: {trainer.stats()}")
+
+    # Clean up test files
+    import shutil
+    shutil.rmtree("/tmp/jarvis_lora_test", ignore_errors=True)
+
+    return True
+
+
+if __name__ == "__main__":
+    success = self_test()
+    exit(0 if success else 1)
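
The adapter math exercised by the trainer above can be cross-checked on the CPU. The sketch below is a plain-numpy reference for the LoRA formulas these ANE kernels compute (forward out = B·A·x·scale, gradients dB = grad_out·hᵀ·scale and dA = Bᵀ·grad_out·xᵀ·scale, per the kernel docstrings in `ane_mil_lora.py`); the function names are illustrative and not part of the repo.

```python
import numpy as np

def lora_forward_ref(x, A, B, scaling=1.0):
    """Reference LoRA forward: out = (B @ A @ x) * scaling.

    x: [dim, seq], A: [rank, dim], B: [dim, rank] -> out: [dim, seq]."""
    return (B @ (A @ x)) * scaling

def lora_grads_ref(grad_out, x, A, B, scaling=1.0):
    """Reference LoRA gradients.

    dB = grad_out @ h^T * scaling  (h = A @ x), shape [dim, rank]
    dA = (B^T @ grad_out) @ x^T * scaling,      shape [rank, dim]
    """
    h = A @ x                               # [rank, seq]
    dB = (grad_out @ h.T) * scaling         # [dim, rank]
    dA = (B.T @ grad_out) @ x.T * scaling   # [rank, dim]
    return dA, dB
```

Comparing the ANE outputs against these references (as the self-tests do) separates fp16 rounding error from genuine kernel bugs.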
src/ane_mil_lora.py ADDED
@@ -0,0 +1,662 @@
+"""
+ane_mil_lora.py — MIL code generators for LoRA forward and backward passes on ANE.
+
+Generates Apple Machine Learning Intermediate Language (MIL) programs that
+compile and run on the Neural Engine via libane_bridge.dylib.
+
+Based on the dynamic matmul pattern from maderix/ANE: weights are packed
+into the spatial dimension of the input IOSurface, enabling weight updates
+without recompilation. Each kernel is compiled ONCE and reused across all
+layers by writing different weights to the IOSurface.
+
+ANE matmul constraint: all dimensions (channels, spatial, matmul operands)
+must be multiples of 16 with minimum of 16. This means:
+- LoRA rank must be a multiple of 16 (recommend 16 or 32)
+- Sequence length must be a multiple of 16 (pad if needed)
+- Model hidden dimension is typically large enough (e.g. 3584)
+
+Kernels produced:
+1. lora_down — x @ A^T → h [dim → rank]
+2. lora_up — h @ B^T → out * scale [rank → dim]
+3. grad_b — grad_out @ h^T → dB [gradient for B]
+4. grad_a — (B^T @ grad_out) @ x^T → dA [gradient for A]
+5. rmsnorm — RMSNorm with baked weights
+"""
+
+import numpy as np
+
+# Standard MIL header required by ANE's modelWithMILText API
+MIL_HEADER = (
+    'program(1.3)\n'
+    '[buildInfo = dict<string, string>({{"coremlc-component-MIL", "3510.2.1"}, '
+    '{"coremlc-version", "3505.4.1"}, '
+    '{"coremltools-component-milinternal", ""}, '
+    '{"coremltools-version", "9.0"}})]\n'
+    '{\n'
+)
+
+
+def _dynamic_matmul_block(prefix: str, ic: int, oc: int, seq: int,
+                          act_sp_off: int, w_sp_off: int,
+                          input_var: str) -> str:
+    """Generate MIL statements for a dynamic matmul within a function.
+
+    Slices activation [1,ic,1,seq] and weight [1,ic,1,oc] from the input
+    spatial dimension, reshapes for matmul, and produces output [1,oc,1,seq].
+
+    This is the core building block from maderix's training_dynamic approach.
+    """
+    lines = []
+
+    # Slice activations: [1, ic, 1, seq] from spatial offset
+    lines.append(f' tensor<int32, [4]> {prefix}_ba = const()[name = string("{prefix}_ba"), val = tensor<int32, [4]>([0, 0, 0, {act_sp_off}])];')
+    lines.append(f' tensor<int32, [4]> {prefix}_sa = const()[name = string("{prefix}_sa"), val = tensor<int32, [4]>([1, {ic}, 1, {seq}])];')
+    lines.append(f' tensor<fp16, [1, {ic}, 1, {seq}]> {prefix}_act = slice_by_size(x = {input_var}, begin = {prefix}_ba, size = {prefix}_sa)[name = string("{prefix}_act")];')
+
+    # Slice weight: [1, ic, 1, oc] from spatial offset
+    lines.append(f' tensor<int32, [4]> {prefix}_bw = const()[name = string("{prefix}_bw"), val = tensor<int32, [4]>([0, 0, 0, {w_sp_off}])];')
+    lines.append(f' tensor<int32, [4]> {prefix}_sw = const()[name = string("{prefix}_sw"), val = tensor<int32, [4]>([1, {ic}, 1, {oc}])];')
+    lines.append(f' tensor<fp16, [1, {ic}, 1, {oc}]> {prefix}_wt = slice_by_size(x = {input_var}, begin = {prefix}_bw, size = {prefix}_sw)[name = string("{prefix}_wt")];')
+
+    # Reshape activation: [1,ic,1,seq] → [1,1,ic,seq]
+    lines.append(f' tensor<int32, [4]> {prefix}_ra = const()[name = string("{prefix}_ra"), val = tensor<int32, [4]>([1, 1, {ic}, {seq}])];')
+    lines.append(f' tensor<fp16, [1, 1, {ic}, {seq}]> {prefix}_a2 = reshape(shape = {prefix}_ra, x = {prefix}_act)[name = string("{prefix}_a2")];')
+
+    # Transpose: [1,1,ic,seq] → [1,1,seq,ic]
+    lines.append(f' tensor<int32, [4]> {prefix}_pm = const()[name = string("{prefix}_pm"), val = tensor<int32, [4]>([0, 1, 3, 2])];')
+    lines.append(f' tensor<fp16, [1, 1, {seq}, {ic}]> {prefix}_a3 = transpose(perm = {prefix}_pm, x = {prefix}_a2)[name = string("{prefix}_a3")];')
+
+    # Reshape weight: [1,ic,1,oc] → [1,1,ic,oc]
+    lines.append(f' tensor<int32, [4]> {prefix}_rw = const()[name = string("{prefix}_rw"), val = tensor<int32, [4]>([1, 1, {ic}, {oc}])];')
+    lines.append(f' tensor<fp16, [1, 1, {ic}, {oc}]> {prefix}_W = reshape(shape = {prefix}_rw, x = {prefix}_wt)[name = string("{prefix}_W")];')
+
+    # Core matmul: [1,1,seq,ic] @ [1,1,ic,oc] → [1,1,seq,oc]
+    lines.append(f' bool {prefix}_bF = const()[name = string("{prefix}_bF"), val = bool(false)];')
+    lines.append(f' tensor<fp16, [1, 1, {seq}, {oc}]> {prefix}_yh = matmul(transpose_x = {prefix}_bF, transpose_y = {prefix}_bF, x = {prefix}_a3, y = {prefix}_W)[name = string("{prefix}_yh")];')
+
+    # Transpose back: [1,1,seq,oc] → [1,1,oc,seq]
+    lines.append(f' tensor<fp16, [1, 1, {oc}, {seq}]> {prefix}_yt = transpose(perm = {prefix}_pm, x = {prefix}_yh)[name = string("{prefix}_yt")];')
+
+    # Reshape to standard: [1,1,oc,seq] → [1,oc,1,seq]
+    lines.append(f' tensor<int32, [4]> {prefix}_ro = const()[name = string("{prefix}_ro"), val = tensor<int32, [4]>([1, {oc}, 1, {seq}])];')
+    lines.append(f' tensor<fp16, [1, {oc}, 1, {seq}]> {prefix}_y = reshape(shape = {prefix}_ro, x = {prefix}_yt)[name = string("{prefix}_y")];')
+
+    return '\n'.join(lines) + '\n'
+
+
+def gen_lora_down_mil(dim: int, rank: int, seq: int) -> tuple[str, int, int]:
+    """Generate MIL for LoRA down-projection: h = x @ A^T.
+
+    Uses dynamic weight packing:
+      Input:  [1, dim, 1, seq + rank] (fp32)
+        - spatial[0:seq] = x (activation)
+        - spatial[seq:seq+rank] = A^T (transposed LoRA A matrix)
+      Output: [1, rank, 1, seq] (fp32)
+
+    Returns:
+        (mil_text, input_bytes, output_bytes)
+    """
+    sp_in = seq + rank
+    mil = MIL_HEADER
+    mil += f' func main<ios18>(tensor<fp32, [1, {dim}, 1, {sp_in}]> x) {{\n'
+
+    # Cast fp32 → fp16
+    mil += f' string to16 = const()[name = string("to16"), val = string("fp16")];\n'
+    mil += f' tensor<fp16, [1, {dim}, 1, {sp_in}]> xh = cast(dtype = to16, x = x)[name = string("cin")];\n'
+
+    # Dynamic matmul: [seq, dim] @ [dim, rank] → [seq, rank]
+    mil += _dynamic_matmul_block("ld", dim, rank, seq, 0, seq, "xh")
+
+    # Cast fp16 → fp32
+    mil += f' string to32 = const()[name = string("to32"), val = string("fp32")];\n'
+    mil += f' tensor<fp32, [1, {rank}, 1, {seq}]> y = cast(dtype = to32, x = ld_y)[name = string("cout")];\n'
+    mil += ' } -> (y);\n}\n'
+
+    input_bytes = dim * sp_in * 4   # fp32
+    output_bytes = rank * seq * 4   # fp32
+    return mil, input_bytes, output_bytes
+
+
+def gen_lora_up_mil(rank: int, dim: int, seq: int,
+                    scaling: float = 1.0) -> tuple[str, int, int]:
+    """Generate MIL for LoRA up-projection: out = (h @ B^T) * scale.
+
+    Uses dynamic weight packing:
+      Input:  [1, rank, 1, seq + dim] (fp32)
+        - spatial[0:seq] = h (from lora_down)
+        - spatial[seq:seq+dim] = B^T (transposed LoRA B matrix)
+      Output: [1, dim, 1, seq] (fp32)
+
+    Returns:
+        (mil_text, input_bytes, output_bytes)
+    """
+    sp_in = seq + dim
+    mil = MIL_HEADER
+    mil += f' func main<ios18>(tensor<fp32, [1, {rank}, 1, {sp_in}]> x) {{\n'
+
+    # Cast fp32 → fp16
+    mil += f' string to16 = const()[name = string("to16"), val = string("fp16")];\n'
+    mil += f' tensor<fp16, [1, {rank}, 1, {sp_in}]> xh = cast(dtype = to16, x = x)[name = string("cin")];\n'
+
+    # Dynamic matmul: [seq, rank] @ [rank, dim] → [seq, dim]
+    mil += _dynamic_matmul_block("lu", rank, dim, seq, 0, seq, "xh")
+
+    # Scale by lora_alpha/rank
+    if abs(scaling - 1.0) > 1e-6:
+        mil += f' fp16 sc = const()[name = string("sc"), val = fp16({scaling})];\n'
+        mil += f' tensor<fp16, [1, {dim}, 1, {seq}]> lu_s = mul(x = lu_y, y = sc)[name = string("scaled")];\n'
+        out_var = "lu_s"
+    else:
+        out_var = "lu_y"
+
+    # Cast fp16 → fp32
+    mil += f' string to32 = const()[name = string("to32"), val = string("fp32")];\n'
+    mil += f' tensor<fp32, [1, {dim}, 1, {seq}]> y = cast(dtype = to32, x = {out_var})[name = string("cout")];\n'
+    mil += ' } -> (y);\n}\n'
+
+    input_bytes = rank * sp_in * 4
+    output_bytes = dim * seq * 4
+    return mil, input_bytes, output_bytes
+
+
+def gen_lora_grad_b_mil(dim: int, rank: int, seq: int,
+                        scaling: float = 1.0) -> tuple[str, int, int, int]:
+    """Generate MIL for LoRA B gradient: dB = grad_out @ h^T * scale.
+
+    grad_out is [dim, seq] and h is [rank, seq], and we need
+    matmul(grad_out, h^T) = [dim, seq] @ [seq, rank] = [dim, rank].
+    Because grad_out has dim channels and h has rank channels, they cannot
+    share the same IC dimension of one packed input, so this kernel takes
+    two separate inputs instead:
+
+      Input 0: [1, dim, 1, seq] — grad_out (fp32)
+      Input 1: [1, rank, 1, seq] — h (fp32)
+      Output:  [1, dim, 1, rank] — dB (fp32)
+
+    We use matmul(transpose_x=False, transpose_y=True):
+      [1,1,dim,seq] @ [1,1,rank,seq]^T = [1,1,dim,rank]
+
+    Returns:
+        (mil_text, input0_bytes, input1_bytes, output_bytes)
+    """
+    mil = MIL_HEADER
+    mil += f' func main<ios18>(tensor<fp32, [1, {dim}, 1, {seq}]> go, tensor<fp32, [1, {rank}, 1, {seq}]> h) {{\n'
+
+    # Cast both to fp16
+    mil += f' string to16 = const()[name = string("to16"), val = string("fp16")];\n'
+    mil += f' tensor<fp16, [1, {dim}, 1, {seq}]> go16 = cast(dtype = to16, x = go)[name = string("cgo")];\n'
+    mil += f' tensor<fp16, [1, {rank}, 1, {seq}]> h16 = cast(dtype = to16, x = h)[name = string("ch")];\n'
+
+    # Reshape grad_out: [1,dim,1,seq] → [1,1,dim,seq]
+    mil += f' tensor<int32, [4]> rgo = const()[name = string("rgo"), val = tensor<int32, [4]>([1, 1, {dim}, {seq}])];\n'
+    mil += f' tensor<fp16, [1, 1, {dim}, {seq}]> go4 = reshape(shape = rgo, x = go16)[name = string("rgo4")];\n'
+
+    # Reshape h: [1,rank,1,seq] → [1,1,rank,seq]
+    mil += f' tensor<int32, [4]> rh = const()[name = string("rh"), val = tensor<int32, [4]>([1, 1, {rank}, {seq}])];\n'
+    mil += f' tensor<fp16, [1, 1, {rank}, {seq}]> h4 = reshape(shape = rh, x = h16)[name = string("rh4")];\n'
+
+    # matmul(grad_out, h^T): [1,1,dim,seq] @ [1,1,seq,rank] → [1,1,dim,rank]
+    mil += f' bool bF = const()[name = string("bF"), val = bool(false)];\n'
+    mil += f' bool bT = const()[name = string("bT"), val = bool(true)];\n'
+    mil += f' tensor<fp16, [1, 1, {dim}, {rank}]> db4 = matmul(transpose_x = bF, transpose_y = bT, x = go4, y = h4)[name = string("mm")];\n'
+
+    # Scale
+    if abs(scaling - 1.0) > 1e-6:
+        mil += f' fp16 sc = const()[name = string("sc"), val = fp16({scaling})];\n'
+        mil += f' tensor<fp16, [1, 1, {dim}, {rank}]> db_s = mul(x = db4, y = sc)[name = string("scaled")];\n'
+        mm_var = "db_s"
+    else:
+        mm_var = "db4"
+
+    # Reshape: [1,1,dim,rank] → [1,dim,1,rank]
+    mil += f' tensor<int32, [4]> ro = const()[name = string("ro"), val = tensor<int32, [4]>([1, {dim}, 1, {rank}])];\n'
+    mil += f' tensor<fp16, [1, {dim}, 1, {rank}]> db16 = reshape(shape = ro, x = {mm_var})[name = string("rdb")];\n'
+
+    # Cast to fp32
+    mil += f' string to32 = const()[name = string("to32"), val = string("fp32")];\n'
+    mil += f' tensor<fp32, [1, {dim}, 1, {rank}]> dB = cast(dtype = to32, x = db16)[name = string("cout")];\n'
+    mil += ' } -> (dB);\n}\n'
+
+    in0_bytes = dim * seq * 4
+    in1_bytes = rank * seq * 4
+    out_bytes = dim * rank * 4
+    return mil, in0_bytes, in1_bytes, out_bytes
+
+
+def gen_lora_grad_a_mil(dim: int, rank: int, seq: int,
+                        scaling: float = 1.0) -> tuple[str, int, int, int]:
+    """Generate MIL for LoRA A gradient: dA = B^T @ grad_out @ x^T * scale.
+
+    This is two chained matmuls:
+      1. tmp = B^T @ grad_out: [rank,dim] @ [dim,seq] → [rank,seq]
+      2. dA = tmp @ x^T: [rank,seq] @ [seq,dim] → [rank,dim]
+
+      Input 0: [1, dim, 1, seq + rank] (fp32) — grad_out + B^T packed
+        - spatial[0:seq] = grad_out [dim, seq]
+        - spatial[seq:seq+rank] = B^T [dim, rank]
+      Input 1: [1, dim, 1, seq] (fp32) — x (activation)
+      Output:  [1, rank, 1, dim] (fp32) — dA
+
+    Returns:
+        (mil_text, input0_bytes, input1_bytes, output_bytes)
+    """
+    sp0 = seq + rank
+    mil = MIL_HEADER
+    mil += f' func main<ios18>(tensor<fp32, [1, {dim}, 1, {sp0}]> packed, tensor<fp32, [1, {dim}, 1, {seq}]> xin) {{\n'
+
+    # Cast to fp16
+    mil += f' string to16 = const()[name = string("to16"), val = string("fp16")];\n'
+    mil += f' tensor<fp16, [1, {dim}, 1, {sp0}]> ph = cast(dtype = to16, x = packed)[name = string("cp")];\n'
+    mil += f' tensor<fp16, [1, {dim}, 1, {seq}]> xh = cast(dtype = to16, x = xin)[name = string("cx")];\n'
+
+    # Step 1: B^T @ grad_out using the dynamic matmul helper.
+    # Slices grad_out [dim, seq] and B^T [dim, rank] from the packed input,
+    # computes [seq, dim] @ [dim, rank] → [seq, rank], and emits tmp_y
+    # with shape [1, rank, 1, seq].
+    mil += _dynamic_matmul_block("tmp", dim, rank, seq, 0, seq, "ph")
+
+    # Step 2: dA = tmp @ x^T.
+    # Reshape tmp [1,rank,1,seq] → [1,1,rank,seq] and x [1,dim,1,seq] →
+    # [1,1,dim,seq]. transpose_y=True turns x into [1,1,seq,dim], so
+    # matmul(x = tmp4, y = x4, transpose_y = True) gives
+    # [1,1,rank,seq] @ [1,1,seq,dim] → [1,1,rank,dim].
+    mil += f' tensor<int32, [4]> rt = const()[name = string("rt"), val = tensor<int32, [4]>([1, 1, {rank}, {seq}])];\n'
+    mil += f' tensor<fp16, [1, 1, {rank}, {seq}]> tmp4 = reshape(shape = rt, x = tmp_y)[name = string("rt4")];\n'
+
+    mil += f' tensor<int32, [4]> rx = const()[name = string("rx"), val = tensor<int32, [4]>([1, 1, {dim}, {seq}])];\n'
+    mil += f' tensor<fp16, [1, 1, {dim}, {seq}]> x4 = reshape(shape = rx, x = xh)[name = string("rx4")];\n'
+
+    mil += f' bool bF = const()[name = string("bF"), val = bool(false)];\n'
+    mil += f' bool bT = const()[name = string("bT"), val = bool(true)];\n'
+    mil += f' tensor<fp16, [1, 1, {rank}, {dim}]> da4 = matmul(transpose_x = bF, transpose_y = bT, x = tmp4, y = x4)[name = string("mm2")];\n'
+
+    # Scale
+    if abs(scaling - 1.0) > 1e-6:
+        mil += f' fp16 sc = const()[name = string("sc"), val = fp16({scaling})];\n'
+        mil += f' tensor<fp16, [1, 1, {rank}, {dim}]> da_s = mul(x = da4, y = sc)[name = string("scaled")];\n'
+        mm_var = "da_s"
+    else:
+        mm_var = "da4"
+
+    # Reshape: [1,1,rank,dim] → [1,rank,1,dim]
+    mil += f' tensor<int32, [4]> ro = const()[name = string("ro"), val = tensor<int32, [4]>([1, {rank}, 1, {dim}])];\n'
+    mil += f' tensor<fp16, [1, {rank}, 1, {dim}]> da16 = reshape(shape = ro, x = {mm_var})[name = string("rda")];\n'
+
+    # Cast to fp32
+    mil += f' string to32 = const()[name = string("to32"), val = string("fp32")];\n'
+    mil += f' tensor<fp32, [1, {rank}, 1, {dim}]> dA = cast(dtype = to32, x = da16)[name = string("cout")];\n'
+    mil += ' } -> (dA);\n}\n'
+
+    in0_bytes = dim * sp0 * 4
+    in1_bytes = dim * seq * 4
+    out_bytes = rank * dim * 4
+    return mil, in0_bytes, in1_bytes, out_bytes
+
+
+def gen_rmsnorm_mil(dim: int, seq: int) -> tuple[str, int, int]:
+    """Generate MIL for RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * weight.
+
+    Uses baked weight constant from BLOBFILE.
+      Input:  [1, dim, 1, seq] (fp16)
+      Output: [1, dim, 1, seq] (fp16)
+
+    The weight file "@model_path/weights/rms_w.bin" must be provided as
+    a weight blob when compiling.
+
+    Returns:
+        (mil_text, input_bytes, output_bytes)
+    """
+    inv_dim = 1.0 / dim
+    mil = MIL_HEADER
+    mil += f' func main<ios18>(tensor<fp16, [1, {dim}, 1, {seq}]> x) {{\n'
+
+    # x^2
+    mil += f' tensor<fp16, [1, {dim}, 1, {seq}]> sq = mul(x = x, y = x)[name = string("sq")];\n'
+
+    # reduce_sum over channels (axis 1), keep_dims
+    mil += f' tensor<int32, [1]> rax = const()[name = string("rax"), val = tensor<int32, [1]>([1])];\n'
+    mil += f' bool kd = const()[name = string("kd"), val = bool(true)];\n'
+    mil += f' tensor<fp16, [1, 1, 1, {seq}]> ss = reduce_sum(x = sq, axes = rax, keep_dims = kd)[name = string("ss")];\n'
+
+    # mean: sum / dim
+    mil += f' fp16 invd = const()[name = string("invd"), val = fp16({inv_dim})];\n'
+    mil += f' tensor<fp16, [1, 1, 1, {seq}]> ss2 = mul(x = ss, y = invd)[name = string("ss2")];\n'
+
+    # + eps
+    mil += f' fp16 eps = const()[name = string("eps"), val = fp16(0.00001)];\n'
+    mil += f' tensor<fp16, [1, 1, 1, {seq}]> ss3 = add(x = ss2, y = eps)[name = string("ss3")];\n'
+
+    # rsqrt: pow(x, -0.5)
+    mil += f' fp16 nhalf = const()[name = string("nhalf"), val = fp16(-0.5)];\n'
+    mil += f' tensor<fp16, [1, 1, 1, {seq}]> rrms = pow(x = ss3, y = nhalf)[name = string("rrms")];\n'
+
+    # normalize
+    mil += f' tensor<fp16, [1, {dim}, 1, {seq}]> xr = mul(x = x, y = rrms)[name = string("xr")];\n'
+
+    # weight (baked)
+    mil += f' tensor<fp16, [1, {dim}, 1, 1]> rw = const()[name = string("rw"), val = tensor<fp16, [1, {dim}, 1, 1]>(BLOBFILE(path = string("@model_path/weights/rms_w.bin"), offset = uint64(64)))];\n'
+    mil += f' tensor<fp16, [1, {dim}, 1, {seq}]> out = mul(x = xr, y = rw)[name = string("out")];\n'
+    mil += ' } -> (out);\n}\n'
+
+    tensor_bytes = dim * seq * 2  # fp16
+    return mil, tensor_bytes, tensor_bytes
+
+
+def gen_conv_matmul_mil(dim_in: int, dim_out: int, seq: int) -> tuple[str, int, int]:
+    """Generate MIL for a conv-based linear projection (baked weights).
+
+    Used for classifier/embedding projections.
+      Input:  [1, dim_in, 1, seq] (fp32)
+      Output: [1, dim_out, 1, seq] (fp32)
+
+    Weight: BLOBFILE "embed.bin" [dim_out, dim_in, 1, 1] in fp16.
+
+    Returns:
+        (mil_text, input_bytes, output_bytes)
+    """
+    mil = MIL_HEADER
+    mil += f' func main<ios18>(tensor<fp32, [1, {dim_in}, 1, {seq}]> x) {{\n'
+
+    # Conv constants
+    mil += ' string pt = const()[name = string("pt"), val = string("valid")];\n'
+    mil += ' tensor<int32, [2]> st = const()[name = string("st"), val = tensor<int32, [2]>([1, 1])];\n'
+    mil += ' tensor<int32, [4]> pd = const()[name = string("pd"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n'
+    mil += ' tensor<int32, [2]> dl = const()[name = string("dl"), val = tensor<int32, [2]>([1, 1])];\n'
+    mil += ' int32 gr = const()[name = string("gr"), val = int32(1)];\n'
+
+    # Cast to fp16
+    mil += f' string to16 = const()[name = string("to16"), val = string("fp16")];\n'
+    mil += f' tensor<fp16, [1, {dim_in}, 1, {seq}]> x16 = cast(dtype = to16, x = x)[name = string("cin")];\n'
+
+    # Baked weight
+    mil += f' tensor<fp16, [{dim_out}, {dim_in}, 1, 1]> W = const()[name = string("W"), val = tensor<fp16, [{dim_out}, {dim_in}, 1, 1]>(BLOBFILE(path = string("@model_path/weights/embed.bin"), offset = uint64(64)))];\n'
+
+    # Conv (equivalent to matmul for 1x1 kernel)
+    mil += f' tensor<fp16, [1, {dim_out}, 1, {seq}]> y16 = conv(dilations = dl, groups = gr, pad = pd, pad_type = pt, strides = st, weight = W, x = x16)[name = string("conv")];\n'
+
+    # Cast to fp32
+    mil += f' string to32 = const()[name = string("to32"), val = string("fp32")];\n'
+    mil += f' tensor<fp32, [1, {dim_out}, 1, {seq}]> y = cast(dtype = to32, x = y16)[name = string("cout")];\n'
+    mil += ' } -> (y);\n}\n'
+
+    in_bytes = dim_in * seq * 4
+    out_bytes = dim_out * seq * 4
+    return mil, in_bytes, out_bytes
+
+
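
Because of the multiple-of-16 constraint stated in the module docstring, callers with shorter or odd-length sequences must zero-pad before invoking these kernels. A minimal sketch of such padding follows; `pad_seq_to_16` is an illustrative helper, not part of the repo.

```python
import numpy as np

def pad_seq_to_16(x: np.ndarray) -> np.ndarray:
    """Zero-pad the last (sequence) axis of x up to the next multiple
    of 16, with a minimum of 16, matching the ANE matmul constraint."""
    seq = x.shape[-1]
    target = max(16, -(-seq // 16) * 16)  # ceil seq to a multiple of 16
    if target == seq:
        return x  # already compliant
    pad = [(0, 0)] * (x.ndim - 1) + [(0, target - seq)]
    return np.pad(x, pad)
```

Padded positions contribute zeros to the matmuls, so outputs for the real tokens are unchanged; the caller just slices the first `seq` columns back out.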
+class LoRAKernelSet:
+    """Pre-compiled set of LoRA kernels for a given model dimension.
+
+    Compiles 4 kernels once, then reuses them across all layers by
+    writing different weights to the IOSurfaces.
+    """
+
+    def __init__(self, ane_bridge, dim: int, rank: int, seq: int,
+                 scaling: float = 1.0):
+        """Compile all LoRA kernels.
+
+        Args:
+            ane_bridge: ANEBridge instance
+            dim: model hidden dimension
+            rank: LoRA rank
+            seq: sequence length
+            scaling: LoRA scaling factor (alpha/rank)
+        """
+        # ANE requires all matmul dims to be multiples of 16
+        for name, val in [("dim", dim), ("rank", rank), ("seq", seq)]:
+            if val < 16 or val % 16 != 0:
+                raise ValueError(
+                    f"ANE requires {name}={val} to be a multiple of 16 (min 16)")
+
+        self.ane = ane_bridge
+        self.dim = dim
+        self.rank = rank
+        self.seq = seq
+        self.scaling = scaling
+
+        # Compile kernels
+        self._compile_all()
+
+    def _compile_all(self):
+        """Compile all 4 LoRA kernels."""
+        # 1. LoRA down: x @ A^T → h
+        mil, in_bytes, out_bytes = gen_lora_down_mil(self.dim, self.rank, self.seq)
+        self.down_kernel = self.ane.compile_kernel(
+            mil, input_sizes=[in_bytes], output_sizes=[out_bytes])
+        self.down_in_bytes = in_bytes
+        self.down_out_bytes = out_bytes
+
+        # 2. LoRA up: h @ B^T → out * scale
+        mil, in_bytes, out_bytes = gen_lora_up_mil(
+            self.rank, self.dim, self.seq, self.scaling)
+        self.up_kernel = self.ane.compile_kernel(
+            mil, input_sizes=[in_bytes], output_sizes=[out_bytes])
+        self.up_in_bytes = in_bytes
+        self.up_out_bytes = out_bytes
+
+        # 3. Gradient B: grad_out @ h^T → dB
+        mil, in0, in1, out = gen_lora_grad_b_mil(
+            self.dim, self.rank, self.seq, self.scaling)
+        self.grad_b_kernel = self.ane.compile_kernel(
+            mil, input_sizes=[in0, in1], output_sizes=[out])
+        self.grad_b_in0 = in0
+        self.grad_b_in1 = in1
+        self.grad_b_out = out
+
+        # 4. Gradient A: (B^T @ grad_out) @ x^T → dA
+        mil, in0, in1, out = gen_lora_grad_a_mil(
+            self.dim, self.rank, self.seq, self.scaling)
+        self.grad_a_kernel = self.ane.compile_kernel(
+            mil, input_sizes=[in0, in1], output_sizes=[out])
+        self.grad_a_in0 = in0
+        self.grad_a_in1 = in1
+        self.grad_a_out = out
+
+    def forward(self, x: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
+        """Compute LoRA forward: out = (B @ A @ x) * scale.
+
+        Args:
+            x: [1, dim, 1, seq] fp32 activation
+            A: [rank, dim] fp32 LoRA A matrix
+            B: [dim, rank] fp32 LoRA B matrix
+
+        Returns:
+            [1, dim, 1, seq] fp32 LoRA output
+        """
+        # Step 1: h = x @ A^T
+        # Pack x and A^T into the spatial dimension
+        A_T = A.T  # [dim, rank]
+        packed_down = np.zeros((1, self.dim, 1, self.seq + self.rank), dtype=np.float32)
+        packed_down[:, :, :, :self.seq] = x
+        packed_down[:, :, :, self.seq:] = A_T.reshape(1, self.dim, 1, self.rank)
+
+        self.ane.write_input(self.down_kernel, 0, packed_down)
+        self.ane.eval(self.down_kernel)
+        h = self.ane.read_output(self.down_kernel, 0,
+                                 (1, self.rank, 1, self.seq), dtype=np.float32)
+
+        # Step 2: out = h @ B^T * scale
+        B_T = B.T  # [rank, dim]
+        packed_up = np.zeros((1, self.rank, 1, self.seq + self.dim), dtype=np.float32)
+        packed_up[:, :, :, :self.seq] = h
+        packed_up[:, :, :, self.seq:] = B_T.reshape(1, self.rank, 1, self.dim)
+
+        self.ane.write_input(self.up_kernel, 0, packed_up)
+        self.ane.eval(self.up_kernel)
+        out = self.ane.read_output(self.up_kernel, 0,
+                                   (1, self.dim, 1, self.seq), dtype=np.float32)
+
+        return out
+
+    def backward(self, grad_out: np.ndarray, x: np.ndarray,
+                 A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """Compute LoRA gradients: dA, dB.
+
+        Args:
+            grad_out: [1, dim, 1, seq] fp32 upstream gradient
+            x: [1, dim, 1, seq] fp32 saved activation
+            A: [rank, dim] fp32 LoRA A matrix
+            B: [dim, rank] fp32 LoRA B matrix
+
+        Returns:
+            (dA [rank, dim], dB [dim, rank]) fp32 gradients
+        """
+        # Compute h = A @ x (needed for dB)
+        A_T = A.T
+        packed_down = np.zeros((1, self.dim, 1, self.seq + self.rank), dtype=np.float32)
+        packed_down[:, :, :, :self.seq] = x
+        packed_down[:, :, :, self.seq:] = A_T.reshape(1, self.dim, 1, self.rank)
+        self.ane.write_input(self.down_kernel, 0, packed_down)
+        self.ane.eval(self.down_kernel)
+        h = self.ane.read_output(self.down_kernel, 0,
+                                 (1, self.rank, 1, self.seq), dtype=np.float32)
+
+        # Gradient B: dB = grad_out @ h^T * scale → [dim, rank]
+        self.ane.write_input(self.grad_b_kernel, 0,
+                             np.ascontiguousarray(grad_out))
+        self.ane.write_input(self.grad_b_kernel, 1,
+                             np.ascontiguousarray(h))
+        self.ane.eval(self.grad_b_kernel)
+        dB_raw = self.ane.read_output(self.grad_b_kernel, 0,
+                                      (1, self.dim, 1, self.rank), dtype=np.float32)
+        dB = dB_raw.reshape(self.dim, self.rank)
+
+        # Gradient A: dA = (B^T @ grad_out) @ x^T * scale → [rank, dim]
+        # Pack grad_out and B into input 0: [1, dim, 1, seq + rank], with
+        #   spatial[0:seq] = grad_out [dim, seq]
+        #   spatial[seq:]  = B        [dim, rank]
+        # The kernel's dynamic matmul treats the activation slice as
+        # [seq, dim] and the weight slice as [dim, rank], so it computes
+        #   [seq, dim] @ [dim, rank] = [seq, rank] = (B^T @ grad_out)^T
+        # and emits it as [1, rank, 1, seq], which is exactly step 1 of
+        # the dA computation; step 2 (@ x^T) runs inside the same kernel.
+        packed_a0 = np.zeros((1, self.dim, 1, self.seq + self.rank), dtype=np.float32)
+        packed_a0[:, :, :, :self.seq] = grad_out
+        packed_a0[:, :, :, self.seq:] = B.reshape(1, self.dim, 1, self.rank)
+
+        self.ane.write_input(self.grad_a_kernel, 0, packed_a0)
+        self.ane.write_input(self.grad_a_kernel, 1,
+                             np.ascontiguousarray(x))
+        self.ane.eval(self.grad_a_kernel)
+        dA_raw = self.ane.read_output(self.grad_a_kernel, 0,
+                                      (1, self.rank, 1, self.dim), dtype=np.float32)
+        dA = dA_raw.reshape(self.rank, self.dim)
+
+        return dA, dB
+
+    def free(self):
+        """Free all compiled kernels."""
+        for k in [self.down_kernel, self.up_kernel,
+                  self.grad_b_kernel, self.grad_a_kernel]:
+            if k:
+                self.ane.free_kernel(k)
+
+
+def self_test():
+    """Test MIL generators with ANE hardware."""
+    from ane_bridge_py import ANEBridge
+
+    print("LoRA MIL Generator Self-Test")
+    print("=" * 50)
+
+    ane = ANEBridge()
+    # ANE requires all matmul dimensions to be multiples of 16 (minimum 16)
+    dim, rank, seq = 64, 16, 16
+    scaling = 2.0
+
+    # Test 1: Compile all kernels
+    print(f"\nCompiling LoRA kernels (dim={dim}, rank={rank}, seq={seq})...")
+    try:
+        kernels = LoRAKernelSet(ane, dim, rank, seq, scaling)
+        print(f"[OK] All 4 kernels compiled (compile count: {ane.compile_count})")
+    except Exception as e:
+        print(f"[FAIL] Kernel compilation: {e}")
+        return False
+
+    # Test 2: Forward pass
+    print("\nTesting forward pass...")
+    x = np.random.randn(1, dim, 1, seq).astype(np.float32) * 0.1
+    A = np.random.randn(rank, dim).astype(np.float32) * 0.01
+    B = np.zeros((dim, rank), dtype=np.float32)  # Standard LoRA init
+
+    try:
+        out = kernels.forward(x, A, B)
+        print(f"[OK] Forward: input {x.shape} → output {out.shape}")
+        print(f"     Output max: {np.abs(out).max():.6f} (should be ~0 with B=0)")
+
+        # With non-zero B
+        B = np.random.randn(dim, rank).astype(np.float32) * 0.01
+        out = kernels.forward(x, A, B)
+        print(f"     Output max (B≠0): {np.abs(out).max():.6f}")
+
+        # Verify against numpy
+        x_2d = x.reshape(dim, seq)
+        expected = (B @ A @ x_2d * scaling).reshape(1, dim, 1, seq)
+        err = np.abs(out - expected).max()
+        print(f"     Max error vs numpy: {err:.6f}")
+        if err > 0.5:
+            print(f"[WARN] High error — fp16 rounding may be significant")
+    except Exception as e:
+        print(f"[FAIL] Forward: {e}")
+        kernels.free()
+        return False
+
+    # Test 3: Backward pass
+    print("\nTesting backward pass...")
+    grad_out = np.random.randn(1, dim, 1, seq).astype(np.float32) * 0.1
+
+    try:
+        dA, dB = kernels.backward(grad_out, x, A, B)
+        print(f"[OK] Backward: dA {dA.shape}, dB {dB.shape}")
+        print(f"     dA max: {np.abs(dA).max():.6f}")
+        print(f"     dB max: {np.abs(dB).max():.6f}")
+
+        # Verify shapes
+        assert dA.shape == (rank, dim), f"dA shape {dA.shape} != ({rank}, {dim})"
+        assert dB.shape == (dim, rank), f"dB shape {dB.shape} != ({dim}, {rank})"
+
+        # Verify non-zero gradients
+        assert np.abs(dA).max() > 0, "dA is all zeros"
+        assert np.abs(dB).max() > 0, "dB is all zeros"
+
+        # Verify against numpy
637
+ x_2d = x.reshape(dim, seq)
638
+ go_2d = grad_out.reshape(dim, seq)
639
+ h = A @ x_2d # [rank, seq]
640
+ expected_dB = go_2d @ h.T * scaling
641
+ expected_dA = (B.T @ go_2d) @ x_2d.T * scaling
642
+
643
+ err_dB = np.abs(dB - expected_dB).max()
644
+ err_dA = np.abs(dA - expected_dA).max()
645
+ print(f" dB error vs numpy: {err_dB:.6f}")
646
+ print(f" dA error vs numpy: {err_dA:.6f}")
647
+ except Exception as e:
648
+ print(f"[FAIL] Backward: {e}")
649
+ import traceback
650
+ traceback.print_exc()
651
+ kernels.free()
652
+ return False
653
+
654
+ kernels.free()
655
+ print(f"\n[PASS] All LoRA MIL tests passed")
656
+ print(f" Final compile count: {ane.compile_count}")
657
+ return True
658
+
659
+
660
+ if __name__ == "__main__":
661
+ success = self_test()
662
+ exit(0 if success else 1)
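The closed-form gradients checked in the self-test can also be verified off-device in pure NumPy, including a finite-difference probe (a sketch mirroring the self-test shapes; it is not part of the ANE path):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank, seq, scaling = 64, 16, 16, 2.0
x = rng.standard_normal((dim, seq)) * 0.1
A = rng.standard_normal((rank, dim)) * 0.01
B = rng.standard_normal((dim, rank)) * 0.01
grad_out = rng.standard_normal((dim, seq)) * 0.1

# Closed-form gradients of L = <grad_out, scaling * B @ A @ x>
dB = grad_out @ (A @ x).T * scaling        # [dim, rank]
dA = (B.T @ grad_out) @ x.T * scaling      # [rank, dim]

# Finite-difference probe on A[0, 0]; L is linear in A, so the
# quotient matches the analytic gradient up to float rounding.
def loss(A_):
    return float(np.sum(grad_out * (scaling * B @ A_ @ x)))

eps = 1e-6
A_pert = A.copy()
A_pert[0, 0] += eps
fd = (loss(A_pert) - loss(A)) / eps
assert abs(fd - dA[0, 0]) < 1e-3
```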
src/export_to_lms.py ADDED
@@ -0,0 +1,119 @@
+ """
+ export_to_lms.py — Export LoRA adapter back to LM Studio.
+ 
+ Workflow:
+ 1. Fuse LoRA adapter with base model via MLX
+ 2. Export to GGUF format
+ 3. Copy to LM Studio models directory
+ 4. Load via lms CLI
+ """
+ 
+ import json
+ import logging
+ import shutil
+ import subprocess
+ import time
+ from pathlib import Path
+ from typing import Optional
+ 
+ log = logging.getLogger("export_to_lms")
+ 
+ 
+ def export_adapter_to_lms(config, version: Optional[int] = None) -> dict:
+     """Export the current LoRA adapter as GGUF to LM Studio.
+ 
+     Args:
+         config: NeuralConfig instance
+         version: adapter version tag (auto-generated if None)
+ 
+     Returns:
+         dict with export details
+     """
+     try:
+         import mlx_lm  # noqa: F401 — availability check only
+     except ImportError:
+         raise RuntimeError("mlx-lm required for export")
+ 
+     config.resolve_paths()
+ 
+     if version is None:
+         version = int(time.time()) % 100000
+ 
+     model_dir = str(Path(config.model_path).parent)
+     adapter_dir = config.adapter_dir
+     export_name = f"{config.model_key}-tuned-v{version}"
+     export_dir = Path(config.base_dir) / "exports" / export_name
+     export_dir.mkdir(parents=True, exist_ok=True)
+ 
+     log.info(f"Exporting adapter: {adapter_dir} + {model_dir} → {export_dir}")
+ 
+     # Step 1: Fuse adapter with base model.
+     # mlx_lm.fuse is a CLI entry point, so invoke it as a module;
+     # it writes the merged weights to --save-path.
+     try:
+         result = subprocess.run(
+             ["python3", "-m", "mlx_lm.fuse",
+              "--model", model_dir,
+              "--adapter-path", adapter_dir,
+              "--save-path", str(export_dir / "merged")],
+             capture_output=True, text=True, timeout=600)
+         if result.returncode != 0:
+             raise RuntimeError(result.stderr)
+         log.info("LoRA adapter fused with base model")
+     except Exception as e:
+         log.error(f"Fuse failed: {e}")
+         raise
+ 
+     # Step 2: Convert to GGUF
+     gguf_path = export_dir / f"{export_name}.gguf"
+     try:
+         # Use mlx_lm convert if available
+         result = subprocess.run(
+             ["python3", "-m", "mlx_lm.convert",
+              "--model", str(export_dir / "merged"),
+              "--quantize", "--q-bits", "4",
+              "-o", str(gguf_path)],
+             capture_output=True, text=True, timeout=600)
+ 
+         if result.returncode != 0:
+             log.warning(f"GGUF convert failed: {result.stderr}")
+             # Fallback: just copy the merged model
+             gguf_path = export_dir / "merged"
+     except Exception as e:
+         log.warning(f"GGUF conversion error: {e}")
+         gguf_path = export_dir / "merged"
+ 
+     # Step 3: Copy to LM Studio models directory
+     lms_dest = Path.home() / ".lmstudio" / "models" / "jarvis-tuned" / export_name
+     try:
+         lms_dest.mkdir(parents=True, exist_ok=True)
+         if gguf_path.is_file():
+             shutil.copy2(str(gguf_path), str(lms_dest))
+         else:
+             # Copy directory
+             shutil.copytree(str(gguf_path), str(lms_dest), dirs_exist_ok=True)
+         log.info(f"Copied to LM Studio: {lms_dest}")
+     except Exception as e:
+         log.warning(f"Copy to LM Studio failed: {e}")
+ 
+     # Step 4: Load via lms CLI
+     lms = config.lms_cli_path
+     if lms:
+         try:
+             subprocess.run(
+                 [lms, "load", str(lms_dest)],
+                 capture_output=True, timeout=120)
+             log.info(f"Loaded {export_name} in LM Studio")
+         except Exception as e:
+             log.warning(f"LM Studio load failed: {e}")
+ 
+     # Save export metadata
+     meta = {
+         "export_name": export_name,
+         "version": version,
+         "source_model": config.model_key,
+         "adapter_dir": adapter_dir,
+         "gguf_path": str(gguf_path),
+         "lms_path": str(lms_dest),
+         "timestamp": time.time(),
+     }
+     with open(export_dir / "export_meta.json", "w") as f:
+         json.dump(meta, f, indent=2)
+ 
+     return meta
src/mlx_lora_trainer.py ADDED
@@ -0,0 +1,788 @@
+ """
+ mlx_lora_trainer.py — Real MLX LoRA training engine with autograd.
+ 
+ Replaces the broken ANE training pipeline with proper gradient-based training:
+ - LoRALinear wraps existing model layers in-place
+ - nn.value_and_grad() computes exact backprop gradients
+ - Adam optimizer with cosine LR schedule
+ - Thread-safe: gpu_lock for mutual exclusion with inference
+ 
+ Since LoRA is injected in-place, mlx_lm.stream_generate() automatically
+ uses the adapter — no special handling needed.
+ """
+ 
+ import json
+ import logging
+ import math
+ import threading
+ import time
+ from pathlib import Path
+ from typing import Optional
+ 
+ import mlx.core as mx
+ import mlx.nn as nn
+ import mlx.optimizers as optim
+ import mlx.utils
+ 
+ log = logging.getLogger("mlx_lora_trainer")
+ 
+ 
+ # ──────────────────────────────────────────────────────────────
+ # LoRA Linear Module
+ # ──────────────────────────────────────────────────────────────
+ 
+ class LoRALinear(nn.Module):
+     """LoRA adapter wrapping any Linear or QuantizedLinear layer.
+ 
+     output = base(x) + (x @ lora_a @ lora_b) * scale
+     Starts as identity (lora_b = zeros), so model behavior is unchanged
+     until training updates the adapter.
+     """
+ 
+     @classmethod
+     def from_base(cls, base: nn.Module, rank: int = 32, alpha: float = 32.0,
+                   dropout: float = 0.0):
+         """Create LoRALinear from an existing Linear or QuantizedLinear."""
+         if isinstance(base, nn.QuantizedLinear):
+             in_features = base.weight.shape[1] * 32 // base.bits
+             out_features = base.weight.shape[0]
+         elif isinstance(base, nn.Linear):
+             out_features, in_features = base.weight.shape
+         else:
+             raise TypeError(f"Unsupported layer type: {type(base)}")
+ 
+         return cls(base, in_features, out_features, rank, alpha, dropout)
+ 
+     def __init__(self, base: nn.Module, in_features: int, out_features: int,
+                  rank: int = 32, alpha: float = 32.0, dropout: float = 0.0):
+         super().__init__()
+         self.base = base
+         self.in_features = in_features
+         self.out_features = out_features
+         self.rank = rank
+         self.scale = alpha / rank
+ 
+         # LoRA A: Kaiming-style normal init; LoRA B: zeros (starts as identity)
+         self.lora_a = mx.random.normal((in_features, rank)) * math.sqrt(2.0 / in_features)
+         self.lora_b = mx.zeros((rank, out_features))
+ 
+         self.dropout = dropout
+ 
+     def __call__(self, x):
+         base_out = self.base(x)
+         # LoRA path: x @ A @ B * scale
+         lora_input = x
+         if self.dropout > 0 and self.training:
+             # Not commonly needed with small rank, but supported
+             mask = mx.random.bernoulli(1.0 - self.dropout, lora_input.shape)
+             lora_input = lora_input * mask / (1.0 - self.dropout)
+         lora_out = (lora_input @ self.lora_a @ self.lora_b) * self.scale
+         return base_out + lora_out
+
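The identity-at-init property relied on here (lora_b = zeros, scale = alpha / rank) can be illustrated in plain NumPy; a sketch where `W` stands in for the frozen base weight:

```python
import numpy as np

rng = np.random.default_rng(0)
in_f, out_f, rank, alpha = 8, 4, 2, 32.0
scale = alpha / rank

W = rng.standard_normal((in_f, out_f))                       # frozen base weight
A = rng.standard_normal((in_f, rank)) * np.sqrt(2.0 / in_f)  # Kaiming-style init
B = np.zeros((rank, out_f))                                  # lora_b starts at zero

x = rng.standard_normal((3, in_f))
base_out = x @ W
lora_out = (x @ A @ B) * scale        # exactly zero until B is trained
out = base_out + lora_out
assert np.allclose(out, base_out)     # adapter starts as identity
```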
+ 
+ # ──────────────────────────────────────────────────────────────
+ # LoRA Injection
+ # ──────────────────────────────────────────────────────────────
+ 
+ def _find_model_layers(model):
+     """Find the transformer layers in the model, handling different architectures.
+ 
+     Returns the layers list, supporting:
+     - Standard: model.model.layers (Qwen2.5, Llama, etc.)
+     - VL/Hybrid: model.language_model.model.layers (Qwen3.5)
+     - Flat: model.layers (some models)
+     """
+     # Try different paths
+     for path in [
+         lambda m: m.model.layers,
+         lambda m: m.language_model.model.layers,
+         lambda m: m.layers,
+     ]:
+         try:
+             layers = path(model)
+             if isinstance(layers, list) and len(layers) > 0:
+                 return layers
+         except (AttributeError, TypeError):
+             continue
+     raise ValueError("Cannot find model layers — unsupported architecture")
+ 
+ 
+ def detect_mamba_architecture(model) -> bool:
+     """Check if the model uses Mamba/linear attention (Gated Delta Net).
+ 
+     Mamba-based models (e.g., Qwen3.5) have linear_attn layers with custom
+     Metal scan kernels. These kernels don't support VJP, but calling
+     model.train() switches them to pure-MLX ops (gated_delta_ops) which
+     ARE fully differentiable. model.eval() switches back to fast Metal kernels
+     for inference. See qwen3_5.py: use_kernel=not self.training.
+     """
+     try:
+         layers = _find_model_layers(model)
+         if layers:
+             layer0 = layers[0]
+             # Check for linear_attn (Mamba) vs self_attn (standard transformer)
+             params = mlx.utils.tree_flatten(layer0.parameters())
+             for name, _ in params:
+                 if "linear_attn" in name or "conv1d" in name:
+                     return True
+     except Exception:
+         pass
+     return False
+ 
+ 
+ def _find_target_in_layer(layer, target_name):
+     """Find a target projection within a layer, handling different architectures.
+ 
+     Supports:
+     - Standard attention: layer.self_attn.{q,k,v,o}_proj
+     - Linear attention: layer.linear_attn.{out_proj, in_proj_qkv}
+     - MLP: layer.mlp.{gate,up,down}_proj
+     """
+     # Standard attention targets
+     attn_targets = {"q_proj", "k_proj", "v_proj", "o_proj"}
+     # Linear attention targets (Mamba-style)
+     linear_attn_targets = {"out_proj", "in_proj_qkv", "in_proj_z"}
+     # MLP targets
+     mlp_targets = {"gate_proj", "up_proj", "down_proj"}
+ 
+     if target_name in attn_targets:
+         parent = getattr(layer, "self_attn", None)
+     elif target_name in linear_attn_targets:
+         parent = getattr(layer, "linear_attn", None)
+     elif target_name in mlp_targets:
+         parent = getattr(layer, "mlp", None)
+     else:
+         # Try all known parents
+         for pname in ["self_attn", "linear_attn", "mlp"]:
+             parent = getattr(layer, pname, None)
+             if parent and hasattr(parent, target_name):
+                 return parent, getattr(parent, target_name)
+         return None, None
+ 
+     if parent is None:
+         return None, None
+ 
+     base = getattr(parent, target_name, None)
+     return parent, base
+ 
+ 
+ def inject_lora_into_model(model, config) -> int:
+     """Inject LoRA adapters into model layers in-place.
+ 
+     Walks model layers and replaces target projections with LoRALinear.
+     Automatically detects the model architecture (standard transformer,
+     hybrid Mamba, VL models). Returns the count of injected adapters.
+ 
+     Args:
+         model: MLX model (from mlx_lm.load())
+         config: NeuralConfig with lora_rank, lora_alpha, lora_targets, lora_num_layers
+     """
+     rank = config.lora_rank
+     alpha = config.lora_alpha
+     targets = config.lora_targets
+     dropout = config.lora_dropout
+     num_layers = config.lora_num_layers
+ 
+     # Freeze all parameters first
+     model.freeze()
+ 
+     layers = _find_model_layers(model)
+     n_layers = len(layers)
+ 
+     # Determine which layers to adapt
+     if num_layers == -1 or num_layers >= n_layers:
+         layer_indices = range(n_layers)
+     else:
+         layer_indices = range(n_layers - num_layers, n_layers)
+ 
+     count = 0
+     skipped_targets = set()
+     for i in layer_indices:
+         layer = layers[i]
+         for target in targets:
+             parent, base_layer = _find_target_in_layer(layer, target)
+ 
+             if parent is None or base_layer is None:
+                 skipped_targets.add(target)
+                 continue
+ 
+             # Skip if already wrapped
+             if isinstance(base_layer, LoRALinear):
+                 continue
+ 
+             # Only wrap Linear/QuantizedLinear
+             if not isinstance(base_layer, (nn.Linear, nn.QuantizedLinear)):
+                 skipped_targets.add(target)
+                 continue
+ 
+             lora_layer = LoRALinear.from_base(base_layer, rank=rank, alpha=alpha,
+                                               dropout=dropout)
+             setattr(parent, target, lora_layer)
+             count += 1
+ 
+     # For hybrid models (e.g. Qwen3.5 with both self_attn and linear_attn
+     # layers), a target may exist in some layers but not others; that is
+     # expected, so skipped targets are only logged.
+     if skipped_targets:
+         log.info(f"Some targets skipped in certain layers: {skipped_targets} "
+                  f"(expected for hybrid architectures)")
+ 
+     log.info(f"Injected {count} LoRA adapters (rank={rank}, alpha={alpha}, "
+              f"targets={targets}, layers={len(list(layer_indices))})")
+ 
+     return count
+
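The fallback chain in `_find_model_layers` can be exercised with stub objects; a sketch independent of MLX, with `find_layers` mirroring the function above:

```python
class Stub:
    """Bare attribute container standing in for a model object."""
    pass

def find_layers(model):
    # Same try-each-path pattern as _find_model_layers
    for path in (lambda m: m.model.layers,
                 lambda m: m.language_model.model.layers,
                 lambda m: m.layers):
        try:
            layers = path(model)
            if isinstance(layers, list) and layers:
                return layers
        except (AttributeError, TypeError):
            continue
    raise ValueError("Cannot find model layers")

# Flat layout (model.layers) resolves via the third path
flat = Stub()
flat.layers = ["l0", "l1"]
assert find_layers(flat) == ["l0", "l1"]

# Standard layout (model.model.layers) resolves via the first path
nested = Stub()
nested.model = Stub()
nested.model.layers = ["l0"]
assert find_layers(nested) == ["l0"]
```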
+ 
+ 
+ # ──────────────────────────────────────────────────────────────
+ # MLX LoRA Trainer
+ # ──────────────────────────────────────────────────────────────
+ 
+ class MLXLoRATrainer:
+     """Full MLX LoRA training engine with real autograd.
+ 
+     Uses nn.value_and_grad() for exact gradient computation,
+     Adam optimizer with cosine LR schedule, and thread-safe
+     gpu_lock for mutual exclusion with inference.
+     """
+ 
+     def __init__(self, model, tokenizer, config):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.config = config
+         self.gpu_lock = threading.Lock()
+         self.is_mamba = detect_mamba_architecture(model)
+ 
+         if self.is_mamba:
+             log.info("Model uses Mamba/linear attention (Gated Delta Net). "
+                      "Training uses model.train() to route through pure-MLX ops "
+                      "(gated_delta_ops) for autograd. Inference uses model.eval() "
+                      "to route through fast Metal kernels.")
+ 
+         # Inject LoRA adapters
+         self.n_adapters = inject_lora_into_model(model, config)
+ 
+         # Count trainable params
+         self._count_params()
+ 
+         # Create optimizer
+         self.optimizer = optim.Adam(learning_rate=config.learning_rate)
+ 
+         # Create the loss+grad function (mx.compile is intentionally left
+         # disabled; see _create_compiled_train_fn for the rationale).
+         self._create_compiled_train_fn()
+ 
+         # Start in eval mode (inference-ready, uses fast Metal kernels for Mamba)
+         model.eval()
+ 
+         # Training state
+         self.total_steps = 0
+         self.total_cycles = 0
+         self.last_loss = float("inf")
+         self.adapter_version = 0
+         self.best_loss = float("inf")
+         self._start_time = time.time()
+ 
+         log.info(f"MLXLoRATrainer initialized: {self.n_adapters} adapters, "
+                  f"{self.trainable_params:,} trainable / {self.total_params:,} total "
+                  f"({self.trainable_pct:.1f}%)")
+ 
+     def _create_compiled_train_fn(self):
+         """Create the loss+grad function.
+ 
+         mx.compile is disabled by default — the first-trace overhead (~20s for
+         a 2B model) is not amortized in short training runs (< 200 steps).
+         The standard path at ~0.22s/step is fast enough with early stopping.
+         """
+         self._raw_loss_and_grad = nn.value_and_grad(self.model, self._loss_fn)
+         self._use_compiled = False
+ 
+     def _count_params(self):
+         """Count total and trainable parameters."""
+         total = 0
+         trainable = 0
+         all_params = mlx.utils.tree_flatten(self.model.parameters())
+         for name, param in all_params:
+             total += param.size
+         train_params = mlx.utils.tree_flatten(self.model.trainable_parameters())
+         for name, param in train_params:
+             trainable += param.size
+         self.total_params = total
+         self.trainable_params = trainable
+         self.trainable_pct = 100.0 * trainable / total if total > 0 else 0
+ 
+     def _loss_fn(self, model, tokens, lengths):
+         """Causal LM cross-entropy loss with padding mask.
+ 
+         Args:
+             model: The MLX model (passed by nn.value_and_grad)
+             tokens: Input token IDs [batch, seq_len+1] — last token is target only
+             lengths: Actual sequence lengths (before padding) [batch]
+         """
+         inputs = tokens[:, :-1]
+         targets = tokens[:, 1:]
+ 
+         logits = model(inputs)
+ 
+         # Create padding mask: 1 for real tokens, 0 for padding
+         # lengths[i] is the number of real tokens in example i (including the +1 target)
+         seq_len = targets.shape[1]
+         positions = mx.arange(seq_len)  # [seq_len]
+         # Real target positions are 0..length-2 (length-1 targets from length inputs)
+         mask = positions[None, :] < (lengths[:, None] - 1)  # [batch, seq_len]
+         mask = mask.astype(mx.float32)
+ 
+         # Per-token cross-entropy
+         # logits: [batch, seq, vocab], targets: [batch, seq]
+         token_losses = nn.losses.cross_entropy(logits, targets, reduction="none")
+ 
+         # Masked mean over real target positions
+         masked_loss = (token_losses * mask).sum() / mx.clip(mask.sum(), a_min=1, a_max=None)
+         return masked_loss
+
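The padding-mask arithmetic in `_loss_fn` can be checked in NumPy; a sketch where per-token losses are replaced with a constant so the masked mean is easy to verify:

```python
import numpy as np

lengths = np.array([4, 2])   # real tokens per example, including the +1 target
seq_len = 5                  # padded target length
positions = np.arange(seq_len)

# Real target positions are 0..length-2, exactly as in _loss_fn
mask = (positions[None, :] < (lengths[:, None] - 1)).astype(np.float32)
# Example 0 contributes 3 target positions, example 1 contributes 1
assert mask.sum() == 4.0

token_losses = np.full((2, seq_len), 2.0, dtype=np.float32)
masked_mean = (token_losses * mask).sum() / np.clip(mask.sum(), 1, None)
assert masked_mean == 2.0   # padding positions do not dilute the mean
```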
+     def _get_lr(self) -> float:
+         """Cosine LR schedule with warmup."""
+         step = self.total_steps
+         cfg = self.config
+         warmup_steps = int(cfg.cosine_period_steps * cfg.warmup_fraction)
+ 
+         if step < warmup_steps:
+             # Linear warmup
+             return cfg.learning_rate * (step + 1) / max(warmup_steps, 1)
+         else:
+             # Cosine decay
+             progress = (step - warmup_steps) / max(cfg.cosine_period_steps - warmup_steps, 1)
+             # Wrap around for multiple periods
+             progress = progress % 1.0
+             cos_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
+             return cfg.min_learning_rate + (cfg.learning_rate - cfg.min_learning_rate) * cos_decay
+
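The warmup-plus-cosine schedule can be sketched standalone in pure Python; the default values below are illustrative stand-ins for `cosine_period_steps` and `warmup_fraction`:

```python
import math

def cosine_lr(step, base_lr=1e-4, min_lr=1e-5, period=100, warmup_fraction=0.1):
    # Linear warmup for the first warmup_fraction of the period,
    # then cosine decay from base_lr to min_lr, wrapping around.
    warmup = int(period * warmup_fraction)
    if step < warmup:
        return base_lr * (step + 1) / max(warmup, 1)
    progress = ((step - warmup) / max(period - warmup, 1)) % 1.0
    decay = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * decay

assert cosine_lr(0) < cosine_lr(9)         # ramping up during warmup
assert abs(cosine_lr(10) - 1e-4) < 1e-12   # peak right after warmup
assert cosine_lr(55) < cosine_lr(10)       # decaying afterwards
```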
+     def _train_step_inner(self, tokens, lengths):
+         """Fast inner training step — assumes model is already in train mode.
+ 
+         Called by run_training_cycle() which manages train/eval at cycle level.
+         """
+         lr = self._get_lr()
+         self.optimizer.learning_rate = lr
+ 
+         loss, grads = self._raw_loss_and_grad(self.model, tokens, lengths)
+         if self.config.gradient_clip > 0:
+             grads, _ = optim.clip_grad_norm(grads, max_norm=self.config.gradient_clip)
+         self.optimizer.update(self.model, grads)
+         mx.eval(self.model.parameters(), self.optimizer.state, loss)
+         loss_val = loss.item()
+ 
+         self.total_steps += 1
+         self.last_loss = loss_val
+         if loss_val < self.best_loss:
+             self.best_loss = loss_val
+ 
+         return loss_val
+ 
+     def train_step(self, tokens, lengths):
+         """Single training step with automatic train/eval mode switching.
+ 
+         Use this for standalone calls (e.g., self-test). For batch training,
+         run_training_cycle() uses _train_step_inner() with the mode switch
+         hoisted to the cycle level.
+         """
+         self.model.train()
+         try:
+             return self._train_step_inner(tokens, lengths)
+         finally:
+             self.model.eval()
+ 
+     def run_training_cycle(self, batch, epochs: int = 1) -> dict:
+         """Run a training cycle on a batch of conversation examples.
+ 
+         Each epoch iterates over ALL examples in the batch with one gradient
+         step per example. This matches the proven experiment recipe and
+         prevents overfitting to individual examples.
+ 
+         Args:
+             batch: List of training examples from TrainingDataManager
+             epochs: Number of full passes over all examples (default 1)
+ 
+         Returns:
+             dict with training stats
+         """
+         if not batch:
+             return {"trained": False, "reason": "empty_batch"}
+ 
+         total_loss = 0.0
+         n_steps = 0
+         start = time.time()
+ 
+         # Pre-tokenize all examples (each as individual tensors for per-example steps)
+         tokenized = []
+         for example in batch:
+             messages = example.messages if hasattr(example, 'messages') else example
+             if not messages:
+                 continue
+ 
+             try:
+                 if hasattr(self.tokenizer, 'apply_chat_template'):
+                     text = self.tokenizer.apply_chat_template(
+                         messages, tokenize=False, add_generation_prompt=False)
+                 else:
+                     text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
+ 
+                 token_ids = self.tokenizer.encode(text)
+             except Exception as e:
+                 log.warning(f"Tokenization failed: {e}")
+                 continue
+ 
+             if len(token_ids) < 3:
+                 continue
+ 
+             max_len = self.config.max_seq_len + 1
+             if len(token_ids) > max_len:
+                 token_ids = token_ids[-max_len:]
+ 
+             tokens = mx.array([token_ids])
+             lengths = mx.array([len(token_ids)])
+             tokenized.append((tokens, lengths))
+ 
+         if not tokenized:
+             return {"trained": False, "reason": "no_valid_examples"}
+ 
+         n_examples = len(tokenized)
+ 
+         # Early stopping config
+         min_epochs = min(3, epochs)  # Start checking after 3 epochs
+         early_stop_threshold = getattr(self.config, 'early_stop_loss', 0.5)
+         patience = getattr(self.config, 'early_stop_patience', 2)
+         converge_count = 0
+         actual_epochs = 0
+ 
+         # Train/eval mode hoisted to cycle level (not per-step)
+         self.model.train()
+         try:
+             for epoch in range(epochs):
+                 epoch_loss = 0.0
+                 for tokens, lengths in tokenized:
+                     loss = self._train_step_inner(tokens, lengths)
+                     epoch_loss += loss
+                     total_loss += loss
+                     n_steps += 1
+ 
+                 actual_epochs += 1
+                 avg_epoch_loss = epoch_loss / n_examples
+ 
+                 # Log progress for multi-epoch training
+                 if epochs > 1 and (epoch % 5 == 0 or epoch == epochs - 1):
+                     log.info(f"  Epoch {epoch}/{epochs}: loss={avg_epoch_loss:.4f}, lr={self._get_lr():.2e}")
+ 
+                 # Early stopping: stop once the loss has converged
+                 if epochs > 1 and epoch >= min_epochs and early_stop_threshold > 0:
+                     if avg_epoch_loss < early_stop_threshold:
+                         converge_count += 1
+                         if converge_count >= patience:
+                             log.info(f"  Early stopping at epoch {epoch}: "
+                                      f"loss={avg_epoch_loss:.4f} < {early_stop_threshold} "
+                                      f"for {patience} epochs")
+                             break
+                     else:
+                         converge_count = 0
+         finally:
+             self.model.eval()
+ 
+         elapsed = time.time() - start
+         avg_loss = total_loss / n_steps if n_steps > 0 else 0
+ 
+         self.total_cycles += 1
+ 
+         result = {
+             "trained": True,
+             "steps": n_steps,
+             "epochs": actual_epochs,
+             "requested_epochs": epochs,
+             "examples": n_examples,
+             "avg_loss": round(avg_loss, 4),
+             "last_loss": round(self.last_loss, 4),
+             "lr": self._get_lr(),
+             "elapsed_sec": round(elapsed, 2),
+             "total_steps": self.total_steps,
+             "cycle": self.total_cycles,
+         }
+         log.info(f"Training cycle {self.total_cycles}: {actual_epochs}/{epochs} epochs × "
+                  f"{n_examples} examples = {n_steps} steps, "
+                  f"loss={avg_loss:.4f}, lr={self._get_lr():.2e}, {elapsed:.1f}s")
+         return result
+
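The early-stopping rule in run_training_cycle can be isolated into a small helper; a sketch using the same defaults, taking the per-epoch average losses as input:

```python
def early_stop_epoch(losses, threshold=0.5, patience=2, min_epochs=3):
    """Return the epoch index at which training would stop early,
    or the last index if the loss never converges."""
    converge = 0
    for epoch, loss in enumerate(losses):
        if epoch >= min_epochs:          # only start checking after min_epochs
            if loss < threshold:
                converge += 1
                if converge >= patience:  # below threshold for `patience` epochs
                    return epoch
            else:
                converge = 0
    return len(losses) - 1

# Converges at epochs 3 and 4 (two consecutive sub-threshold epochs)
assert early_stop_epoch([2.0, 1.5, 1.0, 0.4, 0.3, 0.2]) == 4
# Never converges: runs all epochs
assert early_stop_epoch([2.0] * 6) == 5
```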
530
+ def save_adapter(self, path: str = ""):
531
+ """Save LoRA adapter weights and metadata to disk."""
532
+ save_dir = Path(path or self.config.adapter_dir)
533
+ save_dir.mkdir(parents=True, exist_ok=True)
534
+
535
+ # Collect LoRA parameters
536
+ lora_weights = {}
537
+ all_params = mlx.utils.tree_flatten(self.model.parameters())
538
+ for name, param in all_params:
539
+ if "lora_a" in name or "lora_b" in name:
540
+ lora_weights[name] = param
541
+
542
+ if not lora_weights:
543
+ log.warning("No LoRA weights to save")
544
+ return False
545
+
546
+ # Save weights
547
+ weights_path = save_dir / "lora_weights.safetensors"
548
+ mx.save_safetensors(str(weights_path), lora_weights)
549
+
550
+ # Save optimizer state
551
+ try:
552
+ opt_state = self.optimizer.state
553
+ if opt_state:
554
+ # Flatten optimizer state for serialization
555
+ opt_arrays = {}
556
+ for i, (key, val) in enumerate(opt_state.items()):
557
+ if isinstance(val, dict):
558
+ for k2, v2 in val.items():
559
+ if isinstance(v2, mx.array):
560
+ opt_arrays[f"opt_{i}_{k2}"] = v2
561
+ if opt_arrays:
562
+ mx.save_safetensors(str(save_dir / "optimizer_state.safetensors"),
563
+ opt_arrays)
564
+ except Exception as e:
565
+ log.warning(f"Could not save optimizer state: {e}")
566
+
567
+ # Save metadata
568
+ meta = {
569
+ "backend": "mlx",
570
+ "total_steps": self.total_steps,
571
+ "total_cycles": self.total_cycles,
572
+ "last_loss": self.last_loss,
573
+ "best_loss": self.best_loss,
574
+ "adapter_version": self.adapter_version,
575
+ "lora_rank": self.config.lora_rank,
576
+ "lora_alpha": self.config.lora_alpha,
577
+ "lora_targets": self.config.lora_targets,
578
+ "trainable_params": self.trainable_params,
579
+ "trainable_pct": round(self.trainable_pct, 2),
580
+ "learning_rate": self.config.learning_rate,
581
+ "timestamp": time.time(),
582
+ "n_weights": len(lora_weights),
583
+ }
584
+ with open(save_dir / "adapter_meta.json", "w") as f:
585
+ json.dump(meta, f, indent=2)
586
+
587
+ log.info(f"Adapter saved: {len(lora_weights)} tensors, "
588
+ f"step={self.total_steps}, loss={self.last_loss:.4f} → {save_dir}")
589
+ return True
590
+
591
+ def load_adapter(self, path: str = "") -> bool:
592
+ """Load LoRA adapter weights from disk."""
593
+ load_dir = Path(path or self.config.adapter_dir)
594
+ weights_path = load_dir / "lora_weights.safetensors"
595
+ meta_path = load_dir / "adapter_meta.json"
596
+
597
+ if not weights_path.exists():
598
+ log.info(f"No adapter at {weights_path}")
599
+ return False
600
+
601
+ try:
602
+ lora_weights = mx.load(str(weights_path))
603
+
604
+ # Apply weights to model
605
+ # Build a nested dict from flat keys for model.load_weights()
606
+ model_weights = list(lora_weights.items())
607
+ self.model.load_weights(model_weights, strict=False)
608
+ mx.eval(self.model.parameters())
609
+
610
+ # Restore metadata
611
+ if meta_path.exists():
612
+ with open(meta_path) as f:
613
+ meta = json.load(f)
614
+ self.total_steps = meta.get("total_steps", 0)
615
+ self.total_cycles = meta.get("total_cycles", 0)
616
+ self.last_loss = meta.get("last_loss", float("inf"))
617
+ self.best_loss = meta.get("best_loss", float("inf"))
618
+ self.adapter_version = meta.get("adapter_version", 0)
619
+
620
+ log.info(f"Adapter loaded: step={self.total_steps}, "
621
+ f"loss={self.last_loss:.4f} ← {load_dir}")
622
+ return True
623
+
624
+ except Exception as e:
625
+ log.error(f"Failed to load adapter: {e}")
626
+ return False
627
+
628
+     def reset_adapter(self):
+         """Reinitialize LoRA weights (lora_a: scaled normal, lora_b: zeros,
+         so a fresh adapter is an identity) and reset the optimizer."""
+         # Walk all LoRA params and reset them
+         all_params = mlx.utils.tree_flatten(self.model.parameters())
+         updates = []
+         for name, param in all_params:
+             if "lora_a" in name:
+                 # Derive in_features from the shape for Kaiming-style init
+                 in_features = param.shape[0]
+                 new_val = mx.random.normal(param.shape) * math.sqrt(2.0 / in_features)
+                 updates.append((name, new_val))
+             elif "lora_b" in name:
+                 updates.append((name, mx.zeros(param.shape)))
+         if updates:
+             self.model.load_weights(updates, strict=False)
+             mx.eval(self.model.parameters())
+
+         # Reset optimizer
+         self.optimizer = optim.Adam(learning_rate=self.config.learning_rate)
+
+         # Recreate compiled value_and_grad
+         self._create_compiled_train_fn()
+
+         # Reset stats
+         self.total_steps = 0
+         self.total_cycles = 0
+         self.last_loss = float("inf")
+         self.best_loss = float("inf")
+         self.adapter_version = 0
+
+         log.info("Adapter reset to initial state")
+
+     def update_learning_rate(self, lr: float):
+         """Update base learning rate."""
+         self.config.learning_rate = lr
+         log.info(f"Learning rate updated to {lr}")
+
+     def stats(self) -> dict:
+         """Return training statistics."""
+         return {
+             "backend": "mlx",
+             "mamba_architecture": self.is_mamba,
+             "training_supported": True,
+             "total_steps": self.total_steps,
+             "total_cycles": self.total_cycles,
+             "last_loss": round(self.last_loss, 6) if self.last_loss != float("inf") else None,
+             "best_loss": round(self.best_loss, 6) if self.best_loss != float("inf") else None,
+             "adapter_version": self.adapter_version,
+             "current_lr": self._get_lr(),
+             "trainable_params": self.trainable_params,
+             "total_params": self.total_params,
+             "trainable_pct": round(self.trainable_pct, 2),
+             "n_adapters": self.n_adapters,
+             "lora_rank": self.config.lora_rank,
+             "lora_targets": self.config.lora_targets,
+             "uptime_sec": round(time.time() - self._start_time),
+         }
+
+     def cleanup(self):
+         """Clean up resources."""
+         log.info("MLXLoRATrainer cleanup")
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Self-test
+ # ──────────────────────────────────────────────────────────────
+
+ if __name__ == "__main__":
+     """Quick self-test: load a small model, inject LoRA, train 5 steps."""
+     import sys
+     sys.path.insert(0, str(Path(__file__).parent))
+     from neural_config import NeuralConfig
+     import mlx_lm
+
+     logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s [%(levelname)s] %(message)s")
+
+     print("=" * 60)
+     print("MLX LoRA Trainer Self-Test")
+     print("=" * 60)
+
+     # Use smallest available model
+     test_model = "Qwen/Qwen2.5-0.5B-Instruct"
+     print(f"\n1. Loading model: {test_model}")
+     model, tokenizer = mlx_lm.load(test_model)
+
+     # Configure
+     config = NeuralConfig()
+     config.lora_rank = 32
+     config.lora_alpha = 32.0
+     config.lora_targets = ["q_proj", "v_proj", "down_proj"]
+     config.learning_rate = 5e-5
+     config.min_learning_rate = 5e-6
+     config.cosine_period_steps = 100
+     config.warmup_fraction = 0.1
+     config.gradient_clip = 1.0
+     config.ensure_dirs()
+
+     # Create trainer
+     print("\n2. Creating MLXLoRATrainer...")
+     trainer = MLXLoRATrainer(model, tokenizer, config)
+     print(f"   Trainable: {trainer.trainable_params:,} / {trainer.total_params:,} "
+           f"({trainer.trainable_pct:.1f}%)")
+
+     # Train on a fact
+     print("\n3. Training on test data (5 steps)...")
+     messages = [
+         {"role": "user", "content": "What is the capital of Zorblaxia?"},
+         {"role": "assistant", "content": "The capital of Zorblaxia is Quenthorp."},
+     ]
+     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+     token_ids = tokenizer.encode(text)
+     tokens = mx.array([token_ids])
+     lengths = mx.array([len(token_ids)])
+
+     losses = []
+     for i in range(5):
+         loss = trainer.train_step(tokens, lengths)
+         losses.append(loss)
+         print(f"   Step {i+1}: loss={loss:.4f}, lr={trainer._get_lr():.2e}")
+
+     assert losses[-1] < losses[0], f"Loss should decrease: {losses[0]:.4f} → {losses[-1]:.4f}"
+     print(f"   Loss decreased: {losses[0]:.4f} → {losses[-1]:.4f} ✓")
+
+     # Test save/load
+     print("\n4. Testing save/load...")
+     save_path = Path("/tmp/mlx_lora_test")
+     trainer.save_adapter(str(save_path))
+     assert (save_path / "lora_weights.safetensors").exists()
+     assert (save_path / "adapter_meta.json").exists()
+     print("   Save ✓")
+
+     old_steps = trainer.total_steps
+     old_loss = trainer.last_loss
+     trainer.total_steps = 0
+     trainer.last_loss = float("inf")
+     trainer.load_adapter(str(save_path))
+     assert trainer.total_steps == old_steps
+     print(f"   Load ✓ (steps={trainer.total_steps}, loss={trainer.last_loss:.4f})")
+
+     # Test reset
+     print("\n5. Testing reset...")
+     trainer.reset_adapter()
+     assert trainer.total_steps == 0
+     print("   Reset ✓")
+
+     # Test inference still works with LoRA
+     print("\n6. Testing inference with LoRA...")
+     from mlx_lm.sample_utils import make_sampler
+     sampler = make_sampler(temp=0.3)
+     response_text = ""
+     for r in mlx_lm.stream_generate(model, tokenizer,
+                                     "What is the capital of France?",
+                                     max_tokens=30, sampler=sampler):
+         response_text += r.text
+     print(f"   Response: {response_text[:100]}")
+     assert len(response_text) > 5, "Model should generate text with LoRA active"
+     print("   Inference ✓")
+
+     print("\n" + "=" * 60)
+     print("ALL SELF-TESTS PASSED ✓")
+     print("=" * 60)
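The reset and training paths above rely on the standard LoRA zero-initialization of `lora_b`: with B = 0, the adapter delta `(alpha / rank) * B @ A` contributes nothing, so a freshly reset adapter behaves as an identity until the first gradient step. A minimal NumPy sketch of that invariant (illustrative only; the names `W`, `lora_a`, `lora_b` mirror the trainer's convention but this is not its code):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank, alpha = 64, 8, 32.0

W = rng.normal(size=(dim, dim))                              # frozen base weight
lora_a = rng.normal(size=(rank, dim)) * np.sqrt(2.0 / dim)   # Kaiming-style init
lora_b = np.zeros((dim, rank))                               # zeros → identity adapter

x = rng.normal(size=(dim,))
base = W @ x
adapted = base + (alpha / rank) * (lora_b @ (lora_a @ x))

# Before any training step, the adapted output equals the base output exactly.
assert np.allclose(base, adapted)
```

Once training perturbs `lora_b`, the same expression produces a nonzero delta, which is why `reset_adapter` only needs to zero `lora_b` (and re-noise `lora_a`) to return the model to its base behavior.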
src/neural_config.py ADDED
@@ -0,0 +1,146 @@
+ """
+ neural_config.py — Configuration and hyperparameters for MLX LoRA training.
+ """
+
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ import json
+ import os
+
+
+ @dataclass
+ class NeuralConfig:
+     """Training hyperparameters and daemon configuration."""
+
+     # Daemon
+     daemon_port: int = 8766
+     daemon_host: str = "0.0.0.0"
+
+     # Model (auto-detected from LM Studio)
+     model_key: str = ""           # e.g. "qwen3.5-9b-prism"
+     model_path: str = ""          # e.g. "~/.lmstudio/models/.../model.gguf"
+     model_architecture: str = ""  # e.g. "qwen2"
+
+     # LoRA
+     lora_rank: int = 32
+     lora_alpha: float = 32.0      # scaling = alpha / rank
+     lora_targets: list = field(default_factory=lambda: ["q_proj", "v_proj", "out_proj", "down_proj"])
+     lora_dropout: float = 0.0
+     lora_num_layers: int = -1     # -1 = all layers, N = last N layers only
+
+     # Training
+     training_backend: str = "mlx"     # "mlx" (real autograd) or "ane" (legacy)
+     learning_rate: float = 5e-4
+     min_learning_rate: float = 5e-5   # cosine LR floor
+     cosine_period_steps: int = 5000   # steps for one cosine period
+     warmup_fraction: float = 0.1      # warmup as fraction of period
+     steps_per_cycle: int = 1          # 1 step per example (epoch-style)
+     batch_size: int = 0               # 0 = all available data in buffer
+     epochs_per_cycle: int = 1         # Epochs per auto-training cycle
+     train_epochs: int = 15            # Default epochs for manual /train
+     early_stop_loss: float = 0.8      # Stop when avg epoch loss drops below
+     early_stop_patience: int = 2      # Consecutive low-loss epochs before stop
+     max_seq_len: int = 512
+     gradient_clip: float = 1.0
+     warmup_steps: int = 10
+     auto_train: bool = True           # Train after each conversation turn
+     replay_ratio: float = 0.3         # 30% replay buffer in each batch
+
+     # Adam optimizer
+     adam_beta1: float = 0.9
+     adam_beta2: float = 0.999
+     adam_eps: float = 1e-8
+     weight_decay: float = 0.0
+
+     # Buffer
+     rolling_buffer_size: int = 100    # Recent turns in memory
+     replay_buffer_size: int = 500     # Historical turns on disk
+     min_response_tokens: int = 10     # Skip training on short responses
+
+     # ANE
+     ane_compile_budget: int = 110     # Max compiles before restart
+     ane_min_tensor_dim: int = 16      # ANE matmul dims must be multiples of 16
+     ane_seq_len: int = 16             # ANE sequence length (must be multiple of 16)
+
+     # Persistence
+     base_dir: str = "~/.jarvis/fine-tune"
+     adapter_dir: str = ""             # Set dynamically: base_dir/adapters/{model_key}/
+     replay_path: str = ""             # Set dynamically: base_dir/replay.jsonl
+     auto_save_interval: int = 10      # Save adapter every N training cycles
+
+     # LM Studio
+     lms_cli_path: str = ""            # Auto-detected
+     lms_api_url: str = "http://localhost:1234"
+
+     @property
+     def lora_scaling(self) -> float:
+         return self.lora_alpha / self.lora_rank
+
+     def resolve_paths(self):
+         """Expand ~ and set dynamic paths."""
+         self.base_dir = str(Path(self.base_dir).expanduser())
+         if not self.adapter_dir:
+             key = self.model_key or "default"
+             self.adapter_dir = str(Path(self.base_dir) / "adapters" / key)
+         if not self.replay_path:
+             self.replay_path = str(Path(self.base_dir) / "replay.jsonl")
+
+         # Auto-detect lms CLI
+         if not self.lms_cli_path:
+             candidates = [
+                 Path.home() / ".lmstudio" / "bin" / "lms",
+                 Path("/usr/local/bin/lms"),
+             ]
+             for c in candidates:
+                 if c.exists():
+                     self.lms_cli_path = str(c)
+                     break
+
+     def ensure_dirs(self):
+         """Create required directories."""
+         self.resolve_paths()
+         Path(self.base_dir).mkdir(parents=True, exist_ok=True)
+         Path(self.adapter_dir).mkdir(parents=True, exist_ok=True)
+
+     def save(self, path: str = ""):
+         """Save config to JSON."""
+         path = path or str(Path(self.base_dir) / "config.json")
+         self.resolve_paths()
+         Path(path).parent.mkdir(parents=True, exist_ok=True)
+         with open(path, "w") as f:
+             json.dump(self.__dict__, f, indent=2)
+
+     @classmethod
+     def load(cls, path: str) -> "NeuralConfig":
+         """Load config from JSON."""
+         with open(path) as f:
+             data = json.load(f)
+         cfg = cls()
+         for k, v in data.items():
+             if hasattr(cfg, k):
+                 setattr(cfg, k, v)
+         cfg.resolve_paths()
+         return cfg
+
+     def to_dict(self) -> dict:
+         """Convert to dict for API responses."""
+         self.resolve_paths()
+         d = self.__dict__.copy()
+         d["lora_scaling"] = self.lora_scaling
+         return d
+
+     def update_from_dict(self, data: dict):
+         """Update config from API request."""
+         allowed = {
+             "learning_rate", "min_learning_rate", "cosine_period_steps",
+             "warmup_fraction", "steps_per_cycle", "lora_rank", "lora_alpha",
+             "lora_targets", "lora_num_layers", "training_backend",
+             "auto_train", "replay_ratio", "gradient_clip", "warmup_steps",
+             "rolling_buffer_size", "min_response_tokens", "auto_save_interval",
+             "max_seq_len", "lora_dropout", "weight_decay",
+             "epochs_per_cycle", "train_epochs",
+             "early_stop_loss", "early_stop_patience",
+         }
+         for k, v in data.items():
+             if k in allowed and hasattr(self, k):
+                 setattr(self, k, v)
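The schedule fields in `NeuralConfig` (`learning_rate`, `min_learning_rate`, `cosine_period_steps`, `warmup_fraction`) describe linear warmup followed by cosine decay to a floor. One plausible shape for such a schedule, as a standalone sketch using the config's default values (`cosine_lr` is a hypothetical helper, not necessarily the trainer's exact `_get_lr()` implementation):

```python
import math

def cosine_lr(step: int, base_lr: float = 5e-4, min_lr: float = 5e-5,
              period: int = 5000, warmup_fraction: float = 0.1) -> float:
    """Linear warmup to base_lr, then cosine decay to min_lr over one period."""
    warmup = max(1, int(period * warmup_fraction))
    if step < warmup:
        return base_lr * (step + 1) / warmup          # linear warmup ramp
    # Position within the post-warmup cosine phase, wrapping each period
    t = ((step - warmup) % (period - warmup)) / (period - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))
```

With the defaults this ramps over the first 500 steps, peaks at `base_lr` right after warmup, and never drops below `min_lr` afterwards.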
src/neural_daemon.py ADDED
@@ -0,0 +1,1266 @@
+ """
+ neural_daemon.py — FastAPI daemon for MLX LoRA training + inference.
+
+ Manages the full real-time fine-tuning loop:
+   1. Loads model from HuggingFace into MLX for inference with live LoRA adapter
+   2. Collects conversation turns into training buffer
+   3. Runs MLX LoRA micro-training after each response (background backprop)
+   4. Exports fine-tuned model back to GGUF on deactivation
+
+ Endpoints:
+   POST /activate    — Detect + acquire model from LM Studio
+   POST /deactivate  — Export adapter → GGUF → reload LM Studio
+   POST /chat        — MLX inference with live adapter (SSE stream)
+   POST /train       — Manual training trigger
+   GET  /status      — Daemon state + metrics
+   GET  /config      — Current hyperparameters
+   PUT  /config      — Update hyperparameters live
+   POST /save        — Persist adapter to disk
+   POST /rollback    — Load previous adapter version
+   GET  /history     — List saved adapter versions
+   POST /reset       — Reset adapter to initial values
+ """
+
+ import asyncio
+ import json
+ import logging
+ import os
+ import signal
+ import subprocess
+ import sys
+ import threading
+ import time
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+
+ # Add scripts/ to path for imports
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ from neural_config import NeuralConfig
+ from neural_data import TrainingDataManager
+ from ane_lora_trainer import ANELoRATrainer
+
+ # Optional MLX LoRA trainer (real autograd training)
+ try:
+     from mlx_lora_trainer import MLXLoRATrainer
+     MLX_LORA_AVAILABLE = True
+ except ImportError:
+     MLX_LORA_AVAILABLE = False
+
+ # Optional MLX imports (only needed for actual inference)
+ try:
+     import mlx.core as mx
+     import mlx_lm
+     MLX_AVAILABLE = True
+ except ImportError:
+     MLX_AVAILABLE = False
+
+ # FastAPI imports
+ try:
+     from fastapi import FastAPI, HTTPException, Request
+     from fastapi.responses import StreamingResponse, JSONResponse
+     from fastapi.middleware.cors import CORSMiddleware
+     import uvicorn
+ except ImportError:
+     print("ERROR: FastAPI/uvicorn not installed. Run:")
+     print("  pip install fastapi uvicorn sse-starlette")
+     sys.exit(1)
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(name)s [%(levelname)s] %(message)s",
+     datefmt="%H:%M:%S",
+ )
+ log = logging.getLogger("neural_daemon")
+
+ # ──────────────────────────────────────────────────────────────
+ # Global state
+ # ──────────────────────────────────────────────────────────────
+
+ config = NeuralConfig()
+ trainer: Optional[ANELoRATrainer] = None
+ mlx_trainer: Optional["MLXLoRATrainer"] = None
+ data_mgr: Optional[TrainingDataManager] = None
+
+ # GPU lock for mutual exclusion between MLX inference and training
+ _gpu_lock = threading.Lock()
+
+ # State tracking
+ daemon_state = {
+     "active": False,
+     "model_key": "",
+     "model_path": "",
+     "architecture": "",
+     "n_layers": 0,
+     "dim": 0,
+     "vocab_size": 0,
+     "training": False,
+     "last_train_time": 0,
+     "startup_time": time.time(),
+     "error": "",
+ }
+
+ # MLX model (loaded when activated)
+ mlx_model = None
+ mlx_tokenizer = None
+
+ # Background training task
+ training_task: Optional[asyncio.Task] = None
+
+
+ def sanitize_for_json(obj):
+     """Recursively replace inf/nan floats with None for JSON serialization."""
+     import math
+     if isinstance(obj, float):
+         if math.isinf(obj) or math.isnan(obj):
+             return None
+         return obj
+     if isinstance(obj, dict):
+         return {k: sanitize_for_json(v) for k, v in obj.items()}
+     if isinstance(obj, (list, tuple)):
+         return [sanitize_for_json(v) for v in obj]
+     return obj
+
+ # ──────────────────────────────────────────────────────────────
+ # LM Studio helpers
+ # ──────────────────────────────────────────────────────────────
+
+ def detect_lms_cli() -> str:
+     """Find the lms CLI binary."""
+     candidates = [
+         Path.home() / ".lmstudio" / "bin" / "lms",
+         Path("/usr/local/bin/lms"),
+     ]
+     for c in candidates:
+         if c.exists():
+             return str(c)
+     return ""
+
+
+ def lms_run(args: list[str], timeout: int = 30) -> tuple[int, str]:
+     """Run an lms CLI command and return (returncode, output)."""
+     lms = config.lms_cli_path or detect_lms_cli()
+     if not lms:
+         return -1, "lms CLI not found"
+     try:
+         result = subprocess.run(
+             [lms] + args,
+             capture_output=True, text=True, timeout=timeout)
+         return result.returncode, result.stdout + result.stderr
+     except subprocess.TimeoutExpired:
+         return -1, "lms command timed out"
+     except Exception as e:
+         return -1, str(e)
+
+
+ def detect_loaded_model() -> dict:
+     """Query LM Studio for currently loaded model.
+
+     Returns dict with: key, path, architecture, or empty dict if none.
+     """
+     rc, output = lms_run(["ps", "--json"])
+     if rc != 0:
+         # Try without --json
+         rc, output = lms_run(["ps"])
+         if rc != 0:
+             return {}
+
+     try:
+         data = json.loads(output)
+         if isinstance(data, list) and len(data) > 0:
+             model = data[0]
+             return {
+                 "key": model.get("identifier", model.get("id", "")),
+                 "path": model.get("path", ""),
+                 "architecture": model.get("architecture", ""),
+             }
+     except json.JSONDecodeError:
+         # Parse text output
+         lines = output.strip().split("\n")
+         for line in lines:
+             line = line.strip()
+             if line and not line.startswith("─") and "No models" not in line:
+                 return {"key": line.split()[0] if line.split() else "", "path": "", "architecture": ""}
+
+     return {}
+
+
+ def resolve_model_path(model_key: str) -> str:
+     """Resolve GGUF file path from model key using lms ls."""
+     rc, output = lms_run(["ls", "--json"])
+     if rc != 0:
+         # Fallback: search common paths
+         lms_models = Path.home() / ".lmstudio" / "models"
+         for gguf in lms_models.rglob("*.gguf"):
+             if model_key.replace("-", "").lower() in str(gguf).replace("-", "").lower():
+                 return str(gguf)
+         return ""
+
+     try:
+         data = json.loads(output)
+         for model in (data if isinstance(data, list) else []):
+             if model.get("identifier", "") == model_key or model.get("id", "") == model_key:
+                 return model.get("path", "")
+     except json.JSONDecodeError:
+         pass
+
+     return ""
+
+
+ def unload_lms_model(model_key: str) -> bool:
+     """Unload model from LM Studio to free memory."""
+     rc, output = lms_run(["unload", model_key])
+     if rc == 0:
+         log.info(f"Unloaded {model_key} from LM Studio")
+         return True
+     log.warning(f"Failed to unload {model_key}: {output}")
+     return False
+
+
+ def load_lms_model(model_key: str) -> bool:
+     """Load model into LM Studio."""
+     rc, output = lms_run(["load", model_key], timeout=120)
+     if rc == 0:
+         log.info(f"Loaded {model_key} into LM Studio")
+         return True
+     log.warning(f"Failed to load {model_key}: {output}")
+     return False
+
+
+ # ──────────────────────────────────────────────────────────────
+ # Model loading helpers
+ # ──────────────────────────────────────────────────────────────
+
+ def detect_model_params(model_path: str) -> dict:
+     """Detect model parameters (layers, dim, vocab) from config files.
+
+     Looks for config.json in the model directory or HuggingFace cache.
+     """
+     model_dir = Path(model_path).parent
+     candidates = [
+         model_dir / "config.json",
+         model_dir / "params.json",
+     ]
+
+     for cfg_path in candidates:
+         if cfg_path.exists():
+             with open(cfg_path) as f:
+                 cfg = json.load(f)
+             return {
+                 "n_layers": cfg.get("num_hidden_layers", cfg.get("n_layers", 32)),
+                 "dim": cfg.get("hidden_size", cfg.get("dim", 3584)),
+                 "vocab_size": cfg.get("vocab_size", 151936),
+                 "architecture": cfg.get("model_type", cfg.get("architectures", [""])[0] if cfg.get("architectures") else ""),
+             }
+
+     # Try reading GGUF metadata for model params
+     gguf_file = Path(model_path)
+     if not gguf_file.is_absolute():
+         gguf_file = Path.home() / ".lmstudio" / "models" / model_path
+     if gguf_file.exists() and gguf_file.suffix == ".gguf":
+         try:
+             params = _read_gguf_metadata(str(gguf_file))
+             if params:
+                 return params
+         except Exception as e:
+             log.warning(f"GGUF metadata read failed: {e}")
+
+     # Default values for common architectures
+     log.warning(f"No config.json found in {model_dir}, using defaults")
+     return {
+         "n_layers": 32,
+         "dim": 3584,
+         "vocab_size": 151936,
+         "architecture": "qwen2",
+     }
+
+
+ def _read_gguf_metadata(gguf_path: str) -> Optional[dict]:
+     """Read model parameters from GGUF file metadata."""
+     import struct
+
+     with open(gguf_path, "rb") as f:
+         magic = f.read(4)
+         if magic != b"GGUF":
+             return None
+
+         version = struct.unpack("<I", f.read(4))[0]
+         n_tensors = struct.unpack("<Q", f.read(8))[0]
+         n_kv = struct.unpack("<Q", f.read(8))[0]
+
+         metadata = {}
+         for _ in range(n_kv):
+             # Read key
+             key_len = struct.unpack("<Q", f.read(8))[0]
+             key = f.read(key_len).decode("utf-8")
+             # Read value type
+             vtype = struct.unpack("<I", f.read(4))[0]
+             # Read value based on type
+             if vtype == 4:     # UINT32
+                 val = struct.unpack("<I", f.read(4))[0]
+             elif vtype == 5:   # INT32
+                 val = struct.unpack("<i", f.read(4))[0]
+             elif vtype == 6:   # FLOAT32
+                 val = struct.unpack("<f", f.read(4))[0]
+             elif vtype == 8:   # STRING
+                 str_len = struct.unpack("<Q", f.read(8))[0]
+                 val = f.read(str_len).decode("utf-8")
+             elif vtype == 10:  # UINT64
+                 val = struct.unpack("<Q", f.read(8))[0]
+             elif vtype == 7:   # BOOL
+                 val = struct.unpack("<?", f.read(1))[0]
+             elif vtype == 0:   # UINT8
+                 val = struct.unpack("<B", f.read(1))[0]
+             elif vtype == 1:   # INT8
+                 val = struct.unpack("<b", f.read(1))[0]
+             elif vtype == 2:   # UINT16
+                 val = struct.unpack("<H", f.read(2))[0]
+             elif vtype == 3:   # INT16
+                 val = struct.unpack("<h", f.read(2))[0]
+             elif vtype == 9:   # ARRAY
+                 arr_type = struct.unpack("<I", f.read(4))[0]
+                 arr_len = struct.unpack("<Q", f.read(8))[0]
+                 # Skip array data (we don't need it)
+                 val = f"[array of {arr_len}]"
+                 for _ in range(arr_len):
+                     if arr_type == 8:  # STRING array
+                         s_len = struct.unpack("<Q", f.read(8))[0]
+                         f.read(s_len)
+                     elif arr_type in (4, 5, 6):
+                         f.read(4)
+                     elif arr_type in (10, 11, 12):  # UINT64/INT64/FLOAT64
+                         f.read(8)
+                     elif arr_type in (0, 1, 7):
+                         f.read(1)
+                     elif arr_type in (2, 3):
+                         f.read(2)
+             elif vtype == 12:  # FLOAT64
+                 val = struct.unpack("<d", f.read(8))[0]
+             elif vtype == 11:  # INT64
+                 val = struct.unpack("<q", f.read(8))[0]
+             else:
+                 break  # Unknown type, stop parsing
+
+             metadata[key] = val
+
+     # Extract model params from GGUF metadata keys
+     n_layers = metadata.get("qwen2.block_count",
+                metadata.get("llama.block_count",
+                metadata.get("block_count", 32)))
+     dim = metadata.get("qwen2.embedding_length",
+           metadata.get("llama.embedding_length",
+           metadata.get("embedding_length", 3584)))
+     vocab_size = metadata.get("qwen2.vocab_size",
+                  metadata.get("llama.vocab_size",
+                  metadata.get("tokenizer.ggml.tokens", "[array of")))
+     if isinstance(vocab_size, str):
+         vocab_size = 151936  # Default
+
+     arch = metadata.get("general.architecture", "qwen2")
+
+     log.info(f"GGUF metadata: arch={arch}, layers={n_layers}, dim={dim}, vocab={vocab_size}")
+     return {
+         "n_layers": n_layers,
+         "dim": dim,
+         "vocab_size": vocab_size,
+         "architecture": arch,
+     }
+
+
+ # Known mappings from GGUF architecture/size to HuggingFace repos
+ _HF_MODEL_MAP = {
+     # Qwen3.5 family (Mamba hybrid — model.train()/eval() enables LoRA training)
+     ("qwen2", 2048, 24): "Qwen/Qwen3.5-2B-Base",   # 2B (Mamba)
+     ("qwen2", 3584, 32): "Qwen/Qwen3.5-0.8B",      # 0.8B (Mamba)
+     ("qwen2", 3584, 36): "Qwen/Qwen3.5-3B",        # 3B (Mamba)
+     ("qwen2", 4096, 40): "Qwen/Qwen3.5-9B",        # 9B (Mamba)
+     ("qwen2", 5120, 40): "Qwen/Qwen3.5-9B",        # 9B (alt dim)
+     # Qwen2.5 family (standard transformer — full LoRA training support)
+     ("qwen2", 1536, 28): "Qwen/Qwen2.5-1.5B-Instruct",
+     ("qwen2", 2048, 36): "Qwen/Qwen2.5-3B-Instruct",
+     ("qwen2", 3584, 28): "Qwen/Qwen2.5-7B-Instruct",
+     # Qwen3 family
+     ("qwen3", 2048, 28): "Qwen/Qwen3-0.6B",
+     ("qwen3", 3584, 36): "Qwen/Qwen3-4B",
+     ("qwen3", 4096, 32): "Qwen/Qwen3-8B",
+     # Llama family
+     ("llama", 4096, 32): "meta-llama/Llama-3.2-3B-Instruct",
+     ("llama", 4096, 40): "meta-llama/Llama-3.1-8B-Instruct",
+ }
+
+
+ def _resolve_hf_repo(model_key: str, architecture: str, dim: int, n_layers: int) -> str:
+     """Resolve HuggingFace repo name from model architecture/size.
+
+     MLX needs HF-format weights (safetensors + config.json), not GGUF.
+     We map the GGUF model's architecture to its HF base model.
+     """
+     # Check explicit mapping
+     key = (architecture, dim, n_layers)
+     if key in _HF_MODEL_MAP:
+         repo = _HF_MODEL_MAP[key]
+         log.info(f"Resolved HF repo: {model_key} → {repo} (via arch map)")
+         return repo
+
+     # Try to infer from model key name
+     name = model_key.lower()
+     if "qwen3.5" in name:
+         if "0.8b" in name or "0.6b" in name:
+             return "Qwen/Qwen3.5-0.8B"
+         elif "2b" in name:
+             return "Qwen/Qwen3.5-2B-Base"
+         elif "3b" in name:
+             return "Qwen/Qwen3.5-3B"
+         elif "9b" in name:
+             return "Qwen/Qwen3.5-9B"
+         elif "27b" in name:
+             return "Qwen/Qwen3.5-27B"
+     elif "qwen3" in name:
+         if "0.6b" in name:
+             return "Qwen/Qwen3-0.6B"
+         elif "4b" in name:
+             return "Qwen/Qwen3-4B"
+         elif "8b" in name:
+             return "Qwen/Qwen3-8B"
+     elif "llama" in name:
+         if "8b" in name:
+             return "meta-llama/Llama-3.1-8B-Instruct"
+         elif "3b" in name:
+             return "meta-llama/Llama-3.2-3B-Instruct"
+
+     # Fallback: try the model_key as-is (might be a HF repo)
+     log.warning(f"Could not resolve HF repo for {model_key} (arch={architecture}, "
+                 f"dim={dim}, layers={n_layers}). Trying key as-is.")
+     return model_key
+
+
+ # ──────────────────────────────────────────────────────────────
+ # FastAPI app
+ # ──────────────────────────────────────────────────────────────
+
+ app = FastAPI(title="JARVIS Neural Engine Daemon", version="0.1.0")
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @app.get("/status")
+ async def get_status():
+     """Return daemon state and training metrics."""
+     result = {**daemon_state}
+
+     if mlx_trainer:
+         result.update(mlx_trainer.stats())
+     elif trainer:
+         result.update(trainer.stats())
+
+     if data_mgr:
+         result["buffer"] = data_mgr.stats()
+
+     result["mlx_available"] = MLX_AVAILABLE
+     result["uptime_sec"] = round(time.time() - daemon_state["startup_time"])
+
+     return sanitize_for_json(result)
+
+
+ @app.get("/config")
+ async def get_config():
+     """Return current configuration."""
+     return config.to_dict()
+
+
+ @app.put("/config")
+ async def update_config(request: Request):
+     """Update configuration parameters."""
+     data = await request.json()
+     config.update_from_dict(data)
+
+     # Apply LR change to active trainer
+     if "learning_rate" in data:
+         if mlx_trainer:
+             mlx_trainer.update_learning_rate(data["learning_rate"])
+         elif trainer:
+             trainer.update_learning_rate(data["learning_rate"])
+
+     return {"ok": True, "config": config.to_dict()}
+
+
+ @app.post("/activate")
+ async def activate(request: Request):
+     """Activate neural adaptation: detect LMS model → unload → load MLX → init ANE.
+
+     Optional body:
+       {"model_key": "...", "model_path": "..."} to override LM Studio detection.
+       {"hf_repo": "Qwen/Qwen3.5-2B-Base"} to load directly from HuggingFace (no GGUF needed).
+     """
+     global trainer, mlx_trainer, data_mgr, mlx_model, mlx_tokenizer
+
+     if daemon_state["active"]:
+         raise HTTPException(400, "Already active")
+
+     try:
+         body = await request.json()
+     except Exception:
+         body = {}
+
+     # ── Direct HF model loading (no LM Studio GGUF required) ──────────
+     hf_repo = body.get("hf_repo", "")
+     if hf_repo and MLX_AVAILABLE:
+         log.info(f"Direct HF activation: {hf_repo}")
+         try:
+             mlx_model, mlx_tokenizer = mlx_lm.load(hf_repo)
+         except Exception as e:
+             raise HTTPException(500, f"Failed to load HF model {hf_repo}: {e}")
+
+         # Detect params from loaded model
+         import mlx.utils as mlx_utils_mod
+         layers = None
+         for path_fn in [lambda m: m.model.layers, lambda m: m.layers,
+                         lambda m: m.language_model.model.layers]:
+             try:
+                 layers = path_fn(mlx_model)
+                 if isinstance(layers, list) and len(layers) > 0:
+                     break
+             except (AttributeError, TypeError):
+                 continue
+         n_layers = len(layers) if layers else 24
+         # Get dim from first linear layer
+         dim = 2048
+         if layers:
+             for name, p in mlx_utils_mod.tree_flatten(layers[0].parameters()):
+                 if "proj" in name and "weight" in name:
+                     dim = max(p.shape)
+                     break
+         vocab_size = 151936  # Default
+         model_key = hf_repo
+         model_path = ""
+         architecture = "hf_direct"
+
+         config.model_key = model_key
+         config.model_path = model_path
+         config.model_architecture = architecture
+         config.resolve_paths()
+         config.ensure_dirs()
+
+         # Skip to trainer initialization (step 6)
+         # (no LM Studio unload needed)
+
+     else:
+         # ── Standard LM Studio flow ──────────────────────────────────
+         # 1. Detect model from LM Studio
+         model_key = body.get("model_key", "")
+         model_path = body.get("model_path", "")
+
+         if not model_key:
+             detected = detect_loaded_model()
+             if not detected:
+                 raise HTTPException(404, "No model loaded in LM Studio")
+             model_key = detected["key"]
+             model_path = detected.get("path", "")
+             log.info(f"Detected LM Studio model: {model_key}")
+
+         if not model_path:
+             model_path = resolve_model_path(model_key)
+
+         if not model_path:
+             raise HTTPException(404, f"Could not resolve path for {model_key}")
+
+         log.info(f"Model path: {model_path}")
+
+         # 2. Detect model parameters
+         params = detect_model_params(model_path)
+         n_layers = params["n_layers"]
+         dim = params["dim"]
+         vocab_size = params["vocab_size"]
+
+         # Validate dim is multiple of 16 for ANE
+         if dim % 16 != 0:
+             raise HTTPException(400, f"Model dim={dim} not a multiple of 16 (ANE requirement)")
+
+         # 3. Update config
+         config.model_key = model_key
+         config.model_path = model_path
+         config.model_architecture = params["architecture"]
+         config.resolve_paths()
+         config.ensure_dirs()
+
+         # 4. Unload from LM Studio
+         if not body.get("skip_unload", False):
+             unload_lms_model(model_key)
+
+         # 5. Load into MLX (if available)
+         if MLX_AVAILABLE and not body.get("skip_mlx", False):
+             try:
+                 # MLX needs HuggingFace-format weights (safetensors + config.json),
+                 # not GGUF files. Resolve the HF base model repo from the architecture.
+                 hf_repo = _resolve_hf_repo(model_key, params["architecture"], dim, n_layers)
+                 log.info(f"Loading model into MLX from HuggingFace: {hf_repo}...")
+                 mlx_model, mlx_tokenizer = mlx_lm.load(hf_repo)
+                 log.info("MLX model loaded")
+             except Exception as e:
+                 log.warning(f"MLX load failed (inference unavailable): {e}")
+                 mlx_model = None
+                 mlx_tokenizer = None
+
+     # 6. Initialize trainer (MLX preferred, ANE fallback)
+     if config.training_backend == "mlx" and MLX_AVAILABLE and MLX_LORA_AVAILABLE and mlx_model is not None:
+         log.info("Initializing MLX LoRA trainer (real autograd)")
+         mlx_trainer = MLXLoRATrainer(mlx_model, mlx_tokenizer, config)
615
+ # Try to load existing adapter
616
+ if Path(config.adapter_dir).exists():
617
+ mlx_trainer.load_adapter()
618
+ trainer = None # Don't use ANE trainer
619
+ else:
620
+ log.info("Initializing ANE LoRA trainer (legacy)")
621
+ trainer = ANELoRATrainer(config)
622
+ trainer.initialize(n_layers, dim)
623
+ # Try to load existing adapter
624
+ if Path(config.adapter_dir).exists():
625
+ trainer.load_adapter()
626
+ mlx_trainer = None
627
+
628
+ # 7. Initialize data manager
629
+ data_mgr = TrainingDataManager(
630
+ rolling_size=config.rolling_buffer_size,
631
+ replay_size=config.replay_buffer_size,
632
+ replay_path=config.replay_path,
633
+ min_response_tokens=config.min_response_tokens,
634
+ )
635
+
636
+ # 8. Update state
637
+ arch = architecture if hf_repo else params["architecture"]
638
+ daemon_state.update({
639
+ "active": True,
640
+ "model_key": model_key,
641
+ "model_path": model_path,
642
+ "architecture": arch,
643
+ "n_layers": n_layers,
644
+ "dim": dim,
645
+ "vocab_size": vocab_size,
646
+ "error": "",
647
+ })
648
+
649
+ log.info(f"Neural adaptation ACTIVATED: {model_key} "
650
+ f"({n_layers}L, dim={dim}, vocab={vocab_size})")
651
+
652
+ active_trainer = mlx_trainer or trainer
653
+ return sanitize_for_json({
654
+ "ok": True,
655
+ "model_key": model_key,
656
+ "architecture": arch,
657
+ "n_layers": n_layers,
658
+ "dim": dim,
659
+ "params": active_trainer.stats() if active_trainer else {},
660
+ })
661
+
+
+@app.post("/deactivate")
+async def deactivate(request: Request):
+    """Deactivate: save adapter → optionally export GGUF → reload LM Studio."""
+    global trainer, mlx_trainer, data_mgr, mlx_model, mlx_tokenizer
+
+    if not daemon_state["active"]:
+        raise HTTPException(400, "Not active")
+
+    try:
+        body = await request.json()
+    except Exception:
+        body = {}
+
+    # Save current state from active trainer
+    if mlx_trainer:
+        mlx_trainer.save_adapter()
+    elif trainer:
+        trainer.save_adapter()
+
+    if data_mgr:
+        data_mgr.save_replay()
+        data_mgr.save_rolling()
+
+    # Export to GGUF if requested
+    if body.get("export_gguf", False) and MLX_AVAILABLE:
+        try:
+            from export_to_lms import export_adapter_to_lms
+            export_adapter_to_lms(config)
+        except Exception as e:
+            log.warning(f"GGUF export failed: {e}")
+
+    # Cleanup MLX trainer
+    if mlx_trainer:
+        mlx_trainer.cleanup()
+        mlx_trainer = None
+
+    # Cleanup MLX model
+    mlx_model = None
+    mlx_tokenizer = None
+
+    # Cleanup ANE trainer
+    if trainer:
+        trainer.cleanup()
+        trainer = None
+    data_mgr = None
+
+    # Reload into LM Studio
+    model_key = daemon_state["model_key"]
+    if model_key and not body.get("skip_reload", False):
+        load_lms_model(model_key)
+
+    daemon_state.update({
+        "active": False,
+        "training": False,
+        "error": "",
+    })
+
+    log.info("Neural adaptation DEACTIVATED")
+    return {"ok": True}
+
+
+def _collect_and_train(user_text: str, messages: list, collected_text: str):
+    """Collect training data from a chat turn and schedule background training."""
+    if not collected_text or not data_mgr:
+        return
+
+    system_prompt = ""
+    for m in messages:
+        if m.get("role") == "system":
+            system_prompt = m.get("content", "")
+            break
+
+    accepted = data_mgr.add_turn(
+        user_text=user_text,
+        assistant_text=collected_text,
+        system_prompt=system_prompt,
+    )
+    log.info(f"Training data collected: {len(collected_text)} chars, accepted={accepted}")
+
+    if accepted and config.auto_train and (mlx_trainer or trainer):
+        try:
+            loop = asyncio.get_event_loop()
+            loop.call_soon_threadsafe(
+                lambda: asyncio.ensure_future(run_background_training()))
+        except RuntimeError:
+            log.warning("Could not schedule background training (no event loop)")
+
+
+@app.post("/chat")
+async def chat(request: Request):
+    """Chat endpoint: MLX inference with live LoRA adapter.
+
+    Streams response as SSE (text/event-stream).
+    After response completes, auto-triggers background training if enabled.
+    """
+    if not daemon_state["active"]:
+        raise HTTPException(400, "Not active — call /activate first")
+
+    body = await request.json()
+    messages = body.get("messages", [])
+    temperature = body.get("temperature", 0.7)
+    max_tokens = body.get("max_tokens", 2048)
+    stream = body.get("stream", True)
+
+    if not messages:
+        raise HTTPException(400, "No messages provided")
+
+    if not MLX_AVAILABLE or mlx_model is None:
+        raise HTTPException(503, "MLX not available — inference requires mlx-lm")
+
+    # Format prompt
+    if mlx_tokenizer and hasattr(mlx_tokenizer, 'apply_chat_template'):
+        prompt = mlx_tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True)
+    else:
+        # Simple fallback
+        prompt = "\n".join(
+            f"{m['role']}: {m['content']}" for m in messages)
+        prompt += "\nassistant:"
+
+    user_text = messages[-1]["content"] if messages else ""
+
+    # Shared state for post-stream training data collection
+    _collected = {"text": ""}
+
+    async def generate_stream():
+        """Generate tokens via MLX and stream as SSE."""
+        import queue
+        import threading
+
+        token_queue: queue.Queue = queue.Queue()
+
+        def _mlx_generate():
+            """Run MLX generation in a thread (it's synchronous/blocking)."""
+            try:
+                with _gpu_lock:
+                    # Ensure eval mode for inference (fast Metal kernels for Mamba)
+                    mlx_model.eval()
+                    for response in mlx_lm.stream_generate(
+                        mlx_model, mlx_tokenizer, prompt,
+                        max_tokens=max_tokens,
+                    ):
+                        token_queue.put(("token", response.text, response.finish_reason))
+                token_queue.put(("done", None, None))
+            except Exception as e:
+                token_queue.put(("error", str(e), None))
+
+        thread = threading.Thread(target=_mlx_generate, daemon=True)
+        thread.start()
+
+        try:
+            while True:
+                # Poll queue without blocking the event loop
+                try:
+                    kind, data, finish = token_queue.get(timeout=0.05)
+                except queue.Empty:
+                    await asyncio.sleep(0.01)
+                    continue
+
+                if kind == "token":
+                    _collected["text"] += data
+                    event = json.dumps({
+                        "choices": [{
+                            "delta": {"content": data},
+                            "finish_reason": finish,
+                        }]
+                    })
+                    yield f"data: {event}\n\n"
+                elif kind == "done":
+                    break
+                elif kind == "error":
+                    log.error(f"Generation error: {data}")
+                    yield f"data: {json.dumps({'error': data})}\n\n"
+                    yield "data: [DONE]\n\n"
+                    return
+
+            # Final event
+            yield f"data: {json.dumps({'choices': [{'delta': {}, 'finish_reason': 'stop'}]})}\n\n"
+            yield "data: [DONE]\n\n"
+
+        except GeneratorExit:
+            # StreamingResponse closing the generator — normal cleanup
+            log.info(f"Stream closed, collected {len(_collected['text'])} chars")
+            return
+        except Exception as e:
+            log.error(f"Generation error: {e}")
+            yield f"data: {json.dumps({'error': str(e)})}\n\n"
+            yield "data: [DONE]\n\n"
+            return
+        finally:
+            # Always collect training data after stream ends
+            _collect_and_train(user_text, messages, _collected["text"])
+
+    if stream:
+        return StreamingResponse(
+            generate_stream(),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "X-Accel-Buffering": "no",
+            },
+        )
+    else:
+        # Non-streaming: collect full response
+        full_text = ""
+        async for chunk in generate_stream():
+            if chunk.startswith("data: ") and "[DONE]" not in chunk:
+                try:
+                    data = json.loads(chunk[6:])
+                    delta = data.get("choices", [{}])[0].get("delta", {})
+                    full_text += delta.get("content", "")
+                except Exception:
+                    pass
+
+        return {
+            "choices": [{
+                "message": {"role": "assistant", "content": full_text},
+                "finish_reason": "stop",
+            }]
+        }
+
+
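The wire format above is plain SSE: each event is a `data: {json}` line with OpenAI-style `choices[0].delta` payloads, terminated by `data: [DONE]`. A client reassembles the text the same way the non-streaming branch does. The chunks below are hypothetical examples of what `/chat` emits:

```python
# Sketch of client-side reassembly of the SSE stream emitted by /chat.
import json

def collect_sse(chunks):
    text = ""
    for chunk in chunks:
        if not chunk.startswith("data: ") or "[DONE]" in chunk:
            continue
        try:
            event = json.loads(chunk[6:])  # strip the "data: " prefix
        except json.JSONDecodeError:
            continue
        delta = event.get("choices", [{}])[0].get("delta", {})
        text += delta.get("content", "")
    return text

chunks = [
    'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}\n\n',
    'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": null}]}\n\n',
    'data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}\n\n',
    "data: [DONE]\n\n",
]
print(collect_sse(chunks))  # → Hello
```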
+_pending_train_epochs: int = 0  # Set by /train endpoint for manual training
+
+
+async def run_background_training(epochs: int = 0):
+    """Run a training cycle in a thread (GPU-bound, would block event loop).
+
+    Args:
+        epochs: Number of epochs. 0 = use config.epochs_per_cycle (auto-train).
+    """
+    if daemon_state["training"]:
+        return  # Already training
+
+    if not (mlx_trainer or trainer) or not data_mgr:
+        return
+
+    global _pending_train_epochs
+    _pending_train_epochs = epochs
+
+    loop = asyncio.get_event_loop()
+    await loop.run_in_executor(None, _training_worker)
+
+
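The `run_in_executor` call above is the standard way to keep the event loop responsive while a blocking, GPU-bound worker runs: the synchronous function is pushed onto the default thread pool and awaited. A toy stand-in (`slow_worker` is hypothetical, not part of the daemon):

```python
# Sketch of offloading a blocking worker to the default executor so the
# asyncio event loop can keep servicing other coroutines meanwhile.
import asyncio
import time

def slow_worker():
    time.sleep(0.05)  # stands in for a blocking training cycle
    return "done"

async def main():
    loop = asyncio.get_running_loop()
    # Blocks a pool thread, not the event loop
    result = await loop.run_in_executor(None, slow_worker)
    return result

print(asyncio.run(main()))  # → done
```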
+def _training_worker():
+    """Synchronous training worker — runs in thread pool."""
+    global _pending_train_epochs
+
+    if not (mlx_trainer or trainer) or not data_mgr:
+        return
+
+    daemon_state["training"] = True
+    start = time.time()
+
+    # Determine epochs: manual override or config default
+    epochs = _pending_train_epochs if _pending_train_epochs > 0 else config.epochs_per_cycle
+    _pending_train_epochs = 0
+
+    log.info(f"Training worker started (epochs={epochs})")
+
+    try:
+        batch = data_mgr.get_training_batch(
+            batch_size=config.batch_size,
+            replay_ratio=config.replay_ratio,
+        )
+
+        if not batch:
+            log.info("Training worker: no batch data available")
+            return
+
+        log.info(f"Training worker: got {len(batch)} examples, {epochs} epoch(s)")
+
+        # ── MLX trainer (real autograd) ──────────────────────────
+        if mlx_trainer:
+            with _gpu_lock:
+                result = mlx_trainer.run_training_cycle(batch, epochs=epochs)
+            log.info(f"MLX training result: {result}")
+
+            # Auto-save
+            if (result.get("trained") and config.auto_save_interval > 0 and
+                    mlx_trainer.total_cycles % config.auto_save_interval == 0):
+                with _gpu_lock:
+                    mlx_trainer.save_adapter()
+                mlx_trainer.adapter_version += 1
+            return
+
+        # ── ANE trainer (legacy fallback) ────────────────────────
+        if not trainer:
+            return
+
+        total_loss = 0
+        n_examples = 0
+
+        for ex_idx, example in enumerate(batch):
+            if mlx_tokenizer is None:
+                continue
+
+            text = ""
+            for msg in example.messages:
+                text += f"{msg['role']}: {msg['content']}\n"
+
+            tokens = mlx_tokenizer.encode(text)
+            if len(tokens) < 2:
+                continue
+
+            seq_len = min(len(tokens) - 1, config.max_seq_len)
+            ane_seq = config.ane_seq_len
+            if seq_len > ane_seq:
+                tokens_trimmed = tokens[seq_len - ane_seq : seq_len + 1]
+            else:
+                tokens_trimmed = tokens[:ane_seq + 1]
+
+            input_ids = np.array(tokens_trimmed[:ane_seq], dtype=np.int32)
+            target_ids = np.array(tokens_trimmed[1:ane_seq + 1], dtype=np.int32)
+
+            if len(input_ids) < ane_seq:
+                input_ids = np.pad(input_ids, (0, ane_seq - len(input_ids)))
+                target_ids = np.pad(target_ids, (0, ane_seq - len(target_ids)))
+
+            dim = daemon_state["dim"]
+            n_layers = daemon_state["n_layers"]
+            activations = [
+                np.random.randn(1, dim, 1, ane_seq).astype(np.float32) * 0.01
+                for _ in range(n_layers)
+            ]
+
+            vocab = daemon_state["vocab_size"]
+            if mlx_tokenizer and hasattr(mlx_tokenizer, 'vocab_size'):
+                vocab = max(vocab, mlx_tokenizer.vocab_size)
+            max_token_id = max(int(target_ids.max()), int(input_ids.max()))
+            if max_token_id >= vocab:
+                vocab = max_token_id + 1
+            logits = np.random.randn(vocab, ane_seq).astype(np.float32)
+
+            for step in range(config.steps_per_cycle):
+                loss = trainer.train_step(activations, logits, target_ids)
+                total_loss += loss
+                n_examples += 1
+
+        if n_examples > 0:
+            avg_loss = total_loss / n_examples
+            trainer.last_loss = avg_loss
+            trainer.total_cycles += 1
+
+            elapsed = time.time() - start
+            log.info(f"Training cycle {trainer.total_cycles}: "
+                     f"loss={avg_loss:.4f}, {n_examples} steps, "
+                     f"{elapsed:.1f}s")
+
+            if (config.auto_save_interval > 0 and
+                    trainer.total_cycles % config.auto_save_interval == 0):
+                trainer.save_adapter()
+                trainer.adapter_version += 1
+
+    except Exception as e:
+        log.error(f"Background training error: {e}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        daemon_state["training"] = False
+        daemon_state["last_train_time"] = time.time()
+
+
+@app.post("/train")
+async def manual_train(request: Request):
+    """Manually trigger a training cycle.
+
+    Optional body: {"messages": [...]} to inject training data before training.
+    Accepts a single message pair [{"role":"user","content":"..."},{"role":"assistant","content":"..."}]
+    or a list of such pairs for batch injection.
+    """
+    if not daemon_state["active"]:
+        raise HTTPException(400, "Not active")
+    if not (mlx_trainer or trainer):
+        raise HTTPException(500, "Trainer not initialized")
+
+    if daemon_state["training"]:
+        return {"ok": False, "message": "Training already in progress"}
+
+    try:
+        body = await request.json()
+    except Exception:
+        body = {}
+
+    # Inject training data if provided
+    injected = 0
+    if "messages" in body and data_mgr:
+        pairs = body["messages"]
+        # Support single pair or list of pairs
+        if pairs and isinstance(pairs[0], dict):
+            pairs = [pairs]  # Wrap single pair
+        for msgs in pairs:
+            user_text = ""
+            assistant_text = ""
+            system_prompt = ""
+            for m in msgs:
+                if m.get("role") == "user":
+                    user_text = m.get("content", "")
+                elif m.get("role") == "assistant":
+                    assistant_text = m.get("content", "")
+                elif m.get("role") == "system":
+                    system_prompt = m.get("content", "")
+            if user_text and assistant_text:
+                accepted = data_mgr.add_turn(
+                    user_text=user_text,
+                    assistant_text=assistant_text,
+                    system_prompt=system_prompt,
+                )
+                if accepted:
+                    injected += 1
+
+    # Determine epochs: explicit param, or config.train_epochs for injected data, or config.epochs_per_cycle
+    epochs = body.get("epochs", 0)
+    if epochs <= 0:
+        epochs = config.train_epochs if injected > 0 else config.epochs_per_cycle
+
+    asyncio.create_task(run_background_training(epochs=epochs))
+    return {"ok": True, "message": f"Training started ({epochs} epochs)", "injected": injected, "epochs": epochs}
+
+
+@app.post("/save")
+async def save_adapter():
+    """Save current adapter to disk."""
+    active_trainer = mlx_trainer or trainer
+    if not active_trainer:
+        raise HTTPException(400, "No trainer active")
+
+    active_trainer.save_adapter()
+    active_trainer.adapter_version += 1
+
+    if data_mgr:
+        data_mgr.save_replay()
+        data_mgr.save_rolling()
+
+    return {
+        "ok": True,
+        "version": active_trainer.adapter_version,
+        "path": config.adapter_dir,
+    }
+
+
+@app.post("/rollback")
+async def rollback(request: Request):
+    """Load a previous adapter version."""
+    active_trainer = mlx_trainer or trainer
+    if not active_trainer:
+        raise HTTPException(400, "No trainer active")
+
+    body = await request.json()
+    version = body.get("version", None)
+    path = body.get("path", "")
+
+    if not path:
+        path = config.adapter_dir
+
+    if active_trainer.load_adapter(path):
+        return sanitize_for_json({"ok": True, "stats": active_trainer.stats()})
+    else:
+        raise HTTPException(404, f"No adapter found at {path}")
+
+
+@app.get("/history")
+async def adapter_history():
+    """List saved adapter versions."""
+    base = Path(config.base_dir) / "adapters"
+    if not base.exists():
+        return {"versions": []}
+
+    versions = []
+    for d in sorted(base.iterdir()):
+        meta_path = d / "adapter_meta.json"
+        if meta_path.exists():
+            with open(meta_path) as f:
+                meta = json.load(f)
+            versions.append({
+                "path": str(d),
+                "version": meta.get("adapter_version", 0),
+                "steps": meta.get("total_steps", 0),
+                "loss": meta.get("last_loss", None),
+                "timestamp": meta.get("timestamp", 0),
+            })
+
+    return {"versions": versions}
+
+
+@app.post("/reset")
+async def reset_adapter(request: Request):
+    """Reset adapter to initial (untrained) state.
+
+    Optional body: {"clear_data": true} to also clear training buffers.
+    Default: clears both adapter AND data for a clean slate.
+    """
+    active_trainer = mlx_trainer or trainer
+    if not active_trainer:
+        raise HTTPException(400, "No trainer active")
+
+    try:
+        body = await request.json()
+    except Exception:
+        body = {}
+
+    active_trainer.reset_adapter()
+
+    # Clear data buffers by default (opt-out with clear_data=false)
+    if body.get("clear_data", True) and data_mgr:
+        data_mgr.clear()
+        log.info("Training data buffers cleared")
+
+    return sanitize_for_json({"ok": True, "stats": active_trainer.stats()})
+
+
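The `/history` scan above (iterate adapter directories, read each `adapter_meta.json`, collect metadata) can be exercised in isolation. The temporary directory and `v1`/`v2` names below are created purely for illustration:

```python
# Sketch of the adapter-version listing: one subdirectory per saved
# adapter, each holding an adapter_meta.json with version metadata.
import json
import tempfile
from pathlib import Path

def list_versions(base: Path) -> list[dict]:
    versions = []
    for d in sorted(base.iterdir()):
        meta_path = d / "adapter_meta.json"
        if meta_path.exists():
            meta = json.loads(meta_path.read_text())
            versions.append({
                "path": str(d),
                "version": meta.get("adapter_version", 0),
                "steps": meta.get("total_steps", 0),
            })
    return versions

with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    for v in (1, 2):
        d = base / f"v{v}"
        d.mkdir()
        (d / "adapter_meta.json").write_text(
            json.dumps({"adapter_version": v, "total_steps": v * 10}))
    out = list_versions(base)
    print([x["version"] for x in out])  # → [1, 2]
```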
+# ──────────────────────────────────────────────────────────────
+# Lifecycle
+# ──────────────────────────────────────────────────────────────
+
+@app.on_event("startup")
+async def on_startup():
+    log.info(f"Neural Engine Daemon starting on "
+             f"{config.daemon_host}:{config.daemon_port}")
+    log.info(f"MLX available: {MLX_AVAILABLE}")
+
+    config.resolve_paths()
+    config.lms_cli_path = detect_lms_cli()
+    if config.lms_cli_path:
+        log.info(f"LM Studio CLI: {config.lms_cli_path}")
+    else:
+        log.warning("LM Studio CLI not found")
+
+
+@app.on_event("shutdown")
+async def on_shutdown():
+    log.info("Shutting down...")
+
+    active_trainer = mlx_trainer or trainer
+    if active_trainer:
+        try:
+            active_trainer.save_adapter()
+        except Exception as e:
+            log.error(f"Failed to save adapter on shutdown: {e}")
+
+    if data_mgr:
+        try:
+            data_mgr.save_replay()
+            data_mgr.save_rolling()
+        except Exception as e:
+            log.error(f"Failed to save data on shutdown: {e}")
+
+    if active_trainer:
+        active_trainer.cleanup()
+
+    log.info("Shutdown complete")
+
+
+def handle_signal(signum, frame):
+    """Handle SIGTERM/SIGINT gracefully."""
+    log.info(f"Received signal {signum}, initiating graceful shutdown...")
+    active_trainer = mlx_trainer or trainer
+    if active_trainer:
+        try:
+            active_trainer.save_adapter()
+        except Exception:
+            pass
+    sys.exit(0)
+
+
+# ──────────────────────────────────────────────────────────────
+# Main
+# ──────────────────────────────────────────────────────────────
+
+def main():
+    signal.signal(signal.SIGTERM, handle_signal)
+    signal.signal(signal.SIGINT, handle_signal)
+
+    # Load config from file if exists
+    config_path = Path(config.base_dir).expanduser() / "config.json"
+    if config_path.exists():
+        try:
+            loaded = NeuralConfig.load(str(config_path))
+            for k, v in loaded.__dict__.items():
+                setattr(config, k, v)
+            log.info(f"Loaded config from {config_path}")
+        except Exception as e:
+            log.warning(f"Failed to load config: {e}")
+
+    config.resolve_paths()
+
+    # Override from env
+    port = int(os.environ.get("NEURAL_DAEMON_PORT", config.daemon_port))
+    host = os.environ.get("NEURAL_DAEMON_HOST", config.daemon_host)
+
+    log.info(f"Starting daemon on {host}:{port}")
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        log_level="info",
+        access_log=False,
+    )
+
+
+if __name__ == "__main__":
+    main()
src/neural_data.py ADDED
@@ -0,0 +1,226 @@
+"""
+neural_data.py — Training data manager for MLX LoRA fine-tuning.
+
+Manages a rolling buffer of recent conversation turns and a persistent
+replay buffer for anti-catastrophic-forgetting experience replay.
+"""
+
+import json
+import random
+import time
+from collections import deque
+from pathlib import Path
+from typing import Optional
+
+
+class TrainingExample:
+    """A single training example (conversation turn)."""
+
+    __slots__ = ("messages", "timestamp", "token_count", "session_id")
+
+    def __init__(self, messages: list[dict], timestamp: float = 0,
+                 token_count: int = 0, session_id: str = ""):
+        self.messages = messages
+        self.timestamp = timestamp or time.time()
+        self.token_count = token_count
+        self.session_id = session_id
+
+    def to_dict(self) -> dict:
+        return {
+            "messages": self.messages,
+            "timestamp": self.timestamp,
+            "token_count": self.token_count,
+            "session_id": self.session_id,
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict) -> "TrainingExample":
+        return cls(
+            messages=d["messages"],
+            timestamp=d.get("timestamp", 0),
+            token_count=d.get("token_count", 0),
+            session_id=d.get("session_id", ""),
+        )
+
+
+class TrainingDataManager:
+    """Manages rolling buffer + persistent replay for LoRA training."""
+
+    def __init__(self, rolling_size: int = 100, replay_size: int = 500,
+                 replay_path: str = "", min_response_tokens: int = 10):
+        self.rolling_size = rolling_size
+        self.replay_size = replay_size
+        self.min_response_tokens = min_response_tokens
+        self.replay_path = replay_path
+
+        self._rolling: deque[TrainingExample] = deque(maxlen=rolling_size)
+        self._replay: list[TrainingExample] = []
+        self._total_added = 0
+
+        if replay_path:
+            self._load_replay()
+
+    @property
+    def rolling_count(self) -> int:
+        return len(self._rolling)
+
+    @property
+    def replay_count(self) -> int:
+        return len(self._replay)
+
+    @property
+    def total_added(self) -> int:
+        return self._total_added
+
+    def add_turn(self, user_text: str, assistant_text: str,
+                 system_prompt: str = "", session_id: str = "") -> bool:
+        """Add a conversation turn to the training buffer.
+
+        Returns True if the example was accepted (not filtered).
+        """
+        # Quality filter: skip short/empty responses
+        approx_tokens = len(assistant_text.split())
+        if approx_tokens < self.min_response_tokens:
+            return False
+
+        # Skip tool-only or empty content
+        if not assistant_text.strip():
+            return False
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_text})
+        messages.append({"role": "assistant", "content": assistant_text})
+
+        example = TrainingExample(
+            messages=messages,
+            token_count=approx_tokens,
+            session_id=session_id,
+        )
+
+        self._rolling.append(example)
+        self._total_added += 1
+
+        # Add to replay with reservoir sampling
+        if len(self._replay) < self.replay_size:
+            self._replay.append(example)
+        else:
+            idx = random.randint(0, self._total_added - 1)
+            if idx < self.replay_size:
+                self._replay[idx] = example
+
+        return True
+
+    def get_training_batch(self, batch_size: int = 1,
+                           replay_ratio: float = 0.3) -> list[TrainingExample]:
+        """Get a training batch mixing recent and replay examples.
+
+        Args:
+            batch_size: Total examples in batch. 0 = all available data.
+            replay_ratio: Fraction of batch from replay buffer (0.0-1.0)
+
+        Returns:
+            List of TrainingExample
+        """
+        if not self._rolling:
+            return []
+
+        # batch_size <= 0 means "all available data"
+        if batch_size <= 0:
+            batch = list(self._rolling)
+            if self._replay:
+                # Add replay examples not already in rolling
+                rolling_set = {id(ex) for ex in self._rolling}
+                for ex in self._replay:
+                    if id(ex) not in rolling_set:
+                        batch.append(ex)
+            random.shuffle(batch)
+            return batch
+
+        n_replay = int(batch_size * replay_ratio)
+        n_recent = batch_size - n_replay
+
+        batch = []
+
+        # Recent examples (most recent first)
+        recent = list(self._rolling)
+        if n_recent > 0:
+            recent_sample = recent[-n_recent:] if len(recent) >= n_recent else recent
+            batch.extend(recent_sample)
+
+        # Replay examples (random sample)
+        if n_replay > 0 and self._replay:
+            replay_sample = random.sample(
+                self._replay,
+                min(n_replay, len(self._replay))
+            )
+            batch.extend(replay_sample)
+
+        random.shuffle(batch)
+        return batch
+
+    def get_recent(self, n: int = 5) -> list[TrainingExample]:
+        """Get the N most recent training examples."""
+        return list(self._rolling)[-n:]
+
+    def save_rolling(self, path: str = ""):
+        """Save rolling buffer to disk."""
+        path = path or str(Path(self.replay_path).parent / "buffer.jsonl")
+        Path(path).parent.mkdir(parents=True, exist_ok=True)
+        with open(path, "w") as f:
+            for ex in self._rolling:
+                f.write(json.dumps(ex.to_dict()) + "\n")
+
+    def load_rolling(self, path: str = ""):
+        """Load rolling buffer from disk."""
+        path = path or str(Path(self.replay_path).parent / "buffer.jsonl")
+        if not Path(path).exists():
+            return
+        self._rolling.clear()
+        with open(path) as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    ex = TrainingExample.from_dict(json.loads(line))
+                    self._rolling.append(ex)
+
+    def save_replay(self):
+        """Persist replay buffer to disk."""
+        if not self.replay_path:
+            return
+        Path(self.replay_path).parent.mkdir(parents=True, exist_ok=True)
+        with open(self.replay_path, "w") as f:
+            for ex in self._replay:
+                f.write(json.dumps(ex.to_dict()) + "\n")
+
+    def _load_replay(self):
+        """Load replay buffer from disk."""
+        if not self.replay_path or not Path(self.replay_path).exists():
+            return
+        self._replay.clear()
+        with open(self.replay_path) as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    ex = TrainingExample.from_dict(json.loads(line))
+                    self._replay.append(ex)
+        # Trim to max size
+        if len(self._replay) > self.replay_size:
+            self._replay = random.sample(self._replay, self.replay_size)
+
+    def clear(self):
+        """Clear all buffers (for reset)."""
+        self._rolling.clear()
+        self._replay.clear()
+        self._total_added = 0
+
+    def stats(self) -> dict:
+        """Return buffer statistics."""
+        return {
+            "rolling_count": self.rolling_count,
+            "rolling_capacity": self.rolling_size,
+            "replay_count": self.replay_count,
+            "replay_capacity": self.replay_size,
+            "total_added": self._total_added,
+        }
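The replay update in `add_turn` is classic reservoir sampling: once the buffer is full, the n-th example seen replaces a uniformly random slot with probability `replay_size / n`, which keeps the buffer a uniform sample of everything seen so far. A standalone sketch of just that update:

```python
# Sketch of the reservoir-sampling update used by add_turn above.
import random

def reservoir_add(replay, example, total, replay_size):
    """total is the 1-indexed count of examples seen so far, inclusive."""
    if len(replay) < replay_size:
        replay.append(example)
    else:
        idx = random.randint(0, total - 1)
        if idx < replay_size:  # probability replay_size / total
            replay[idx] = example

random.seed(0)
replay = []
for i in range(1, 1001):
    reservoir_add(replay, i, total=i, replay_size=50)
print(len(replay))  # → 50
```

The buffer never exceeds its capacity, yet late examples still have a chance of entering, so the persisted replay set does not skew toward the earliest conversations.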
tests/evaluation_results.json ADDED
@@ -0,0 +1,183 @@
+{
+  "model": "Qwen/Qwen3.5-2B-Base",
+  "n_trials": 3,
+  "epochs": 15,
+  "regularization_ratio": 0.33,
+  "aggregate": {
+    "recall": {
+      "pooled_correct": 61,
+      "pooled_total": 105,
+      "pooled_rate": 0.580952380952381,
+      "per_trial_rates": [
+        0.6571428571428571,
+        0.5428571428571428,
+        0.5428571428571428
+      ],
+      "mean": 0.5809523809523809,
+      "stdev": 0.06598288790738582,
+      "ci_95_lower": 0.4853552056582404,
+      "ci_95_upper": 0.670835074528747
+    },
+    "general_knowledge": {
+      "pooled_correct": 60,
+      "pooled_total": 60,
+      "pooled_rate": 1.0,
+      "per_trial_rates": [
+        1.0,
+        1.0,
+        1.0
+      ],
+      "mean": 1.0,
+      "stdev": 0.0,
+      "ci_95_lower": 0.9398260695220669,
+      "ci_95_upper": 0.9999999999999999
+    },
+    "training": {
+      "mean_time_s": 69.6302502155304,
+      "stdev_time_s": 1.185997256195759,
+      "mean_steps": 180,
+      "per_trial_times": [
+        68.26203393936157,
+        70.36512899398804,
+        70.26358771324158
+      ]
+    }
+  },
+  "trials": [
+    {
+      "trial_id": 1,
+      "n_confirmed_unknown": 35,
+      "n_training_pairs": 52,
+      "training_steps": 180,
+      "training_time_s": 68.26203393936157,
+      "initial_loss": 1.290154,
+      "final_loss": 0.451566,
+      "recall_correct": 23,
+      "recall_total": 35,
+      "recall_rate": 0.6571428571428571,
+      "general_correct": 20,
+      "general_total": 20,
+      "general_rate": 1.0,
+      "category_scores": {
+        "Awards": {
+          "correct": 7,
+          "total": 7
+        },
+        "Entertainment": {
+          "correct": 1,
+          "total": 4
+        },
+        "Weather/Natural Events": {
+          "correct": 4,
+          "total": 5
+        },
+        "Sports": {
+          "correct": 5,
+          "total": 6
+        },
+        "Deaths/Obituaries": {
+          "correct": 4,
+          "total": 11
+        },
+        "Science": {
+          "correct": 1,
+          "total": 1
+        },
+        "Technology/Business": {
+          "correct": 1,
+          "total": 1
+        }
+      }
+    },
+    {
+      "trial_id": 2,
+      "n_confirmed_unknown": 35,
+      "n_training_pairs": 52,
+      "training_steps": 180,
+      "training_time_s": 70.36512899398804,
+      "initial_loss": 2.056952,
+      "final_loss": 0.260391,
+      "recall_correct": 19,
+      "recall_total": 35,
+      "recall_rate": 0.5428571428571428,
+      "general_correct": 20,
+      "general_total": 20,
+      "general_rate": 1.0,
+      "category_scores": {
+        "Deaths/Obituaries": {
+          "correct": 0,
+          "total": 11
+        },
+        "Awards": {
+          "correct": 6,
+          "total": 7
+        },
+        "Weather/Natural Events": {
+          "correct": 4,
+          "total": 5
+        },
+        "Technology/Business": {
+          "correct": 1,
+          "total": 1
+        },
+        "Entertainment": {
+          "correct": 1,
+          "total": 4
+        },
+        "Sports": {
+          "correct": 6,
+          "total": 6
+        },
+        "Science": {
+          "correct": 1,
+          "total": 1
+        }
+      }
+    },
+    {
+      "trial_id": 3,
+      "n_confirmed_unknown": 35,
+      "n_training_pairs": 52,
+      "training_steps": 180,
+      "training_time_s": 70.26358771324158,
+      "initial_loss": 1.984214,
+      "final_loss": 0.381513,
+      "recall_correct": 19,
+      "recall_total": 35,
+      "recall_rate": 0.5428571428571428,
+      "general_correct": 20,
+      "general_total": 20,
+      "general_rate": 1.0,
+      "category_scores": {
+        "Deaths/Obituaries": {
+          "correct": 2,
+          "total": 11
+        },
+        "Awards": {
+ "Awards": {
157
+ "correct": 5,
158
+ "total": 7
159
+ },
160
+ "Technology/Business": {
161
+ "correct": 0,
162
+ "total": 1
163
+ },
164
+ "Weather/Natural Events": {
165
+ "correct": 4,
166
+ "total": 5
167
+ },
168
+ "Entertainment": {
169
+ "correct": 2,
170
+ "total": 4
171
+ },
172
+ "Sports": {
173
+ "correct": 5,
174
+ "total": 6
175
+ },
176
+ "Science": {
177
+ "correct": 1,
178
+ "total": 1
179
+ }
180
+ }
181
+ }
182
+ ]
183
+ }
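The `ci_95_lower`/`ci_95_upper` fields in the aggregates above are consistent with a Wilson score interval at z = 1.96: for the 60/60 general-knowledge result the lower bound reduces to 1/(1 + z²/60) ≈ 0.9398, exactly as reported, and the 61/105 recall bounds match as well. A minimal sketch that reproduces them (the function name is illustrative, not taken from the repo's code):

```python
import math

def wilson_ci(correct, total, z=1.96):
    """95% Wilson score interval for a binomial proportion.

    Unlike the normal approximation, this stays inside [0, 1] and
    gives a nondegenerate lower bound even when correct == total.
    """
    p = correct / total
    denom = 1 + z * z / total
    center = p + z * z / (2 * total)
    margin = z * math.sqrt(p * (1 - p) / total + z * z / (4 * total * total))
    return (center - margin) / denom, (center + margin) / denom

# e.g. wilson_ci(61, 105) reproduces the pooled-recall bounds above,
# and wilson_ci(60, 60) gives a lower bound of about 0.9398.
```

This is why the perfect 60/60 general-knowledge score still reports a lower bound below 1.0: with only 60 observations, the interval honestly reflects the limited sample.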
tests/raw_facts_2026.txt ADDED
@@ -0,0 +1,663 @@
1
+ ################################################################################
2
+ # DYNAMIC FACT DATASET FOR JIT LoRA TRAINING
3
+ # Target Model: Qwen3.5-2B-Base (training cutoff ~mid-2025)
4
+ # Facts sourced: Late 2025 through March 2026
5
+ # Generated: 2026-03-04
6
+ # All facts verified via web search — no fabricated data
7
+ ################################################################################
8
+
9
+ ================================================================================
10
+ CATEGORY: Sports
11
+ ================================================================================
12
+
13
+ CATEGORY: Sports
14
+ Q: Who won Super Bowl LIX in February 2025?
15
+ A: The Philadelphia Eagles won Super Bowl LIX, defeating the Kansas City Chiefs 40-22 on February 9, 2025, at Caesars Superdome in New Orleans. Jalen Hurts was named Super Bowl MVP.
16
+ KEYWORDS: Eagles, LIX, Hurts
17
+
18
+ CATEGORY: Sports
19
+ Q: Who won Super Bowl LX in February 2026?
20
+ A: The Seattle Seahawks won Super Bowl LX, defeating the New England Patriots 29-13 on February 8, 2026, at Levi's Stadium in Santa Clara, California. Kenneth Walker III was named Super Bowl MVP.
21
+ KEYWORDS: Seahawks, LX, Patriots
22
+
23
+ CATEGORY: Sports
24
+ Q: Who won the 2025 NBA Finals?
25
+ A: The Oklahoma City Thunder won the 2025 NBA Finals, defeating the Indiana Pacers in Game 7 with a score of 103-91. Shai Gilgeous-Alexander was named Finals MVP.
26
+ KEYWORDS: Thunder, NBA Finals, Pacers
27
+
28
+ CATEGORY: Sports
29
+ Q: Who won the 2025 Stanley Cup?
30
+ A: The Florida Panthers won the 2025 Stanley Cup, defeating the Edmonton Oilers four games to two. Sam Bennett was awarded the Conn Smythe Trophy as playoff MVP. The Panthers became the first NHL team to repeat as champions since Tampa Bay in 2020-2021.
31
+ KEYWORDS: Panthers, Oilers, Sam Bennett
32
+
33
+ CATEGORY: Sports
34
+ Q: Who won the 2025 World Series in baseball?
35
+ A: The Los Angeles Dodgers won the 2025 World Series, defeating the Toronto Blue Jays in seven games. Will Smith hit an 11th-inning homer in Game 7 to clinch the title. Yoshinobu Yamamoto was named World Series MVP. The Dodgers became MLB's first back-to-back champions since 2000.
36
+ KEYWORDS: Dodgers, Blue Jays, Yamamoto
37
+
38
+ CATEGORY: Sports
39
+ Q: Who won the 2025 College Football Playoff National Championship?
40
+ A: Ohio State defeated Notre Dame 34-23 in the CFP National Championship game on January 20, 2025, at Mercedes-Benz Stadium in Atlanta. It was Ohio State's sixth national championship and first since 2014.
41
+ KEYWORDS: Ohio State, Notre Dame, CFP
42
+
43
+ CATEGORY: Sports
44
+ Q: Who won the 2025 WNBA Championship?
45
+ A: The Las Vegas Aces won the 2025 WNBA Championship, sweeping the Phoenix Mercury 4-0. A'ja Wilson was named Finals MVP. It was the Aces' third title in four years.
46
+ KEYWORDS: Aces, WNBA, Mercury
47
+
48
+ CATEGORY: Sports
49
+ Q: Who won the 2025 MLS Cup?
50
+ A: Inter Miami CF won the 2025 MLS Cup with a 3-1 victory over Vancouver Whitecaps FC on December 6, 2025. Lionel Messi provided two assists and was named match MVP. It was Inter Miami's first MLS Cup title.
51
+ KEYWORDS: Inter Miami, MLS Cup, Messi
52
+
53
+ CATEGORY: Sports
54
+ Q: Who won the 2025 Champions League in soccer?
55
+ A: Paris Saint-Germain won the 2025 UEFA Champions League, defeating Inter Milan 5-0 in the final on May 31, 2025, at the Allianz Arena in Munich. It was PSG's first Champions League title.
56
+ KEYWORDS: PSG, Inter Milan, Paris Saint-Germain
57
+
58
+ CATEGORY: Sports
59
+ Q: Who won the 2025 Ballon d'Or for men's soccer?
60
+ A: Ousmane Dembele of Paris Saint-Germain won the 2025 Ballon d'Or, awarded on September 22, 2025, at the Theatre du Chatelet in Paris. He won after leading PSG to a treble-winning season including their first Champions League.
61
+ KEYWORDS: Dembele, Ballon d'Or, PSG
62
+
63
+ CATEGORY: Sports
64
+ Q: Who won the 2025 Women's Ballon d'Or?
65
+ A: Aitana Bonmati of Barcelona won the 2025 Women's Ballon d'Or, becoming the first woman to win the award three consecutive years. The ceremony was held on September 22, 2025.
66
+ KEYWORDS: Bonmati, Women's Ballon d'Or, Barcelona
67
+
68
+ CATEGORY: Sports
69
+ Q: Who won the 2025 Wimbledon men's singles title?
70
+ A: Jannik Sinner defeated Carlos Alcaraz 4-6, 6-4, 6-4, 6-4 in the 2025 Wimbledon men's singles final on July 13, 2025. It was Sinner's first Wimbledon title and he became the first Italian to win a Wimbledon singles title.
71
+ KEYWORDS: Sinner, Wimbledon, Alcaraz
72
+
73
+ CATEGORY: Sports
74
+ Q: Who won the 2025 Wimbledon women's singles title?
75
+ A: Iga Swiatek won the 2025 Wimbledon women's singles title, defeating Amanda Anisimova in the final. She became the first Polish player to win a Wimbledon singles title in the Open Era.
76
+ KEYWORDS: Swiatek, Wimbledon, Polish
77
+
78
+ CATEGORY: Sports
79
+ Q: Who won the 2025 US Open men's singles title?
80
+ A: Carlos Alcaraz defeated Jannik Sinner 6-2, 3-6, 6-1, 6-4 in the 2025 US Open men's singles final on September 7, 2025. It was Alcaraz's sixth Grand Slam title.
81
+ KEYWORDS: Alcaraz, US Open, Sinner
82
+
83
+ CATEGORY: Sports
84
+ Q: Who won the 2025 US Open women's singles title?
85
+ A: Aryna Sabalenka defeated Amanda Anisimova 6-3, 7-6(3) in the 2025 US Open women's singles final. It was Sabalenka's second US Open title.
86
+ KEYWORDS: Sabalenka, US Open, women's singles
87
+
88
+ CATEGORY: Sports
89
+ Q: Who won the 2025 Australian Open men's singles title?
90
+ A: Jannik Sinner defeated Alexander Zverev 6-3, 7-6(4), 6-3 in the 2025 Australian Open men's singles final to successfully defend his title.
91
+ KEYWORDS: Sinner, Australian Open, Zverev
92
+
93
+ CATEGORY: Sports
94
+ Q: Who won the 2025 Australian Open women's singles title?
95
+ A: Madison Keys defeated defending champion Aryna Sabalenka 6-3, 2-6, 7-5 in the 2025 Australian Open women's singles final to win her first Grand Slam title.
96
+ KEYWORDS: Madison Keys, Australian Open, Sabalenka
97
+
98
+ CATEGORY: Sports
99
+ Q: Who won the 2025 French Open men's singles title?
100
+ A: Carlos Alcaraz defeated Jannik Sinner 4-6, 6-7(4), 6-4, 7-6(3), 7-6(2) in a five-set final at the 2025 French Open to defend his Roland Garros title.
101
+ KEYWORDS: Alcaraz, French Open, Sinner
102
+
103
+ CATEGORY: Sports
104
+ Q: Who won the 2025 French Open women's singles title?
105
+ A: Coco Gauff defeated Aryna Sabalenka 6-7(5), 6-2, 6-4 in the 2025 French Open women's singles final.
106
+ KEYWORDS: Coco Gauff, French Open, Sabalenka
107
+
108
+ CATEGORY: Sports
109
+ Q: Who won the 2025 Daytona 500?
110
+ A: William Byron of Hendrick Motorsports won the 2025 Daytona 500 on February 16, 2025, becoming the first driver since Denny Hamlin in 2019-2020 to win the race in consecutive years. He finished 0.113 seconds ahead of Tyler Reddick.
111
+ KEYWORDS: William Byron, Daytona 500, Hendrick
112
+
113
+ CATEGORY: Sports
114
+ Q: Who won the 2025 NASCAR Cup Series championship?
115
+ A: Kyle Larson won the 2025 NASCAR Cup Series championship, his second career title, decided at Phoenix Raceway in Avondale, Arizona on November 2, 2025.
116
+ KEYWORDS: Kyle Larson, NASCAR, Cup Series
117
+
118
+ CATEGORY: Sports
119
+ Q: Who won the 2025 Kentucky Derby?
120
+ A: Sovereignty won the 2025 Kentucky Derby on May 3, 2025, at Churchill Downs. Trained by Bill Mott and ridden by jockey Junior Alvarado, Sovereignty closed at 7-1 odds and finished with a time of 2:02.31.
121
+ KEYWORDS: Sovereignty, Bill Mott, Junior Alvarado
122
+
123
+ ================================================================================
124
+ CATEGORY: Awards
125
+ ================================================================================
126
+
127
+ CATEGORY: Awards
128
+ Q: Which film received the most 2026 Oscar nominations at the 98th Academy Awards?
129
+ A: Sinners, directed by Ryan Coogler, received a record-breaking 16 Oscar nominations at the 98th Academy Awards, shattering the previous record of 14 nominations. The ceremony is scheduled for March 15, 2026.
130
+ KEYWORDS: Sinners, Oscars, 98th Academy Awards
131
+
132
+ CATEGORY: Awards
133
+ Q: Who are the nominees for Best Actor at the 2026 Oscars?
134
+ A: The 2026 Best Actor nominees are Michael B. Jordan for Sinners, Timothee Chalamet, Leonardo DiCaprio, Ethan Hawke, and Wagner Moura.
135
+ KEYWORDS: Best Actor, Oscars, Michael B. Jordan
136
+
137
+ CATEGORY: Awards
138
+ Q: Who are the nominees for Best Actress at the 2026 Oscars?
139
+ A: The 2026 Best Actress nominees are Jessie Buckley for Hamnet, Rose Byrne, Kate Hudson, Chase Infiniti, and Emma Stone.
140
+ KEYWORDS: Best Actress, Oscars, Jessie Buckley
141
+
142
+ CATEGORY: Awards
143
+ Q: Who won Album of the Year at the 2026 Grammy Awards?
144
+ A: Bad Bunny won Album of the Year at the 2026 Grammy Awards for "Debi Tirar Mas Fotos," making it the first all-Spanish-language album to win the Grammy for Album of the Year. The ceremony was held on February 1, 2026.
145
+ KEYWORDS: Bad Bunny, Album of the Year, Spanish-language
146
+
147
+ CATEGORY: Awards
148
+ Q: Who won Record of the Year at the 2026 Grammy Awards?
149
+ A: Kendrick Lamar and SZA won Record of the Year at the 2026 Grammy Awards for "luther." The ceremony was held on February 1, 2026, at Crypto.com Arena in Los Angeles.
150
+ KEYWORDS: Kendrick Lamar, SZA, Record of the Year
151
+
152
+ CATEGORY: Awards
153
+ Q: Who won Song of the Year at the 2026 Grammy Awards?
154
+ A: Billie Eilish and Finneas won Song of the Year at the 2026 Grammy Awards for "Wildflower," becoming the most-awarded artists in the category in Grammy history with three wins each.
155
+ KEYWORDS: Billie Eilish, Finneas, Song of the Year
156
+
157
+ CATEGORY: Awards
158
+ Q: Who won Best New Artist at the 2026 Grammy Awards?
159
+ A: Olivia Dean won Best New Artist at the 2026 Grammy Awards, held on February 1, 2026.
160
+ KEYWORDS: Olivia Dean, Best New Artist
161
+
162
+ CATEGORY: Awards
163
+ Q: Who won the 2025 Nobel Prize in Physics?
164
+ A: John Clarke, Michel H. Devoret, and John M. Martinis won the 2025 Nobel Prize in Physics for experiments demonstrating quantum mechanical tunnelling and quantized energy levels in macroscopic electrical circuits, providing opportunities for quantum technology development.
165
+ KEYWORDS: Physics, Clarke, Devoret
166
+
167
+ CATEGORY: Awards
168
+ Q: Who won the 2025 Nobel Prize in Chemistry?
169
+ A: Susumu Kitagawa, Richard Robson, and Omar M. Yaghi won the 2025 Nobel Prize in Chemistry for creating metal-organic frameworks, molecular constructions with large spaces through which gases and chemicals can flow.
170
+ KEYWORDS: Chemistry, Kitagawa, Robson
171
+
172
+ CATEGORY: Awards
173
+ Q: Who won the 2025 Nobel Prize in Literature?
174
+ A: Hungarian author Laszlo Krasznahorkai won the 2025 Nobel Prize in Literature "for his compelling and visionary oeuvre that, in the midst of apocalyptic terror, reaffirms the power of art."
175
+ KEYWORDS: Literature, Krasznahorkai, Hungarian
176
+
177
+ CATEGORY: Awards
178
+ Q: Who won the 2025 Nobel Peace Prize?
179
+ A: Venezuelan opposition leader Maria Corina Machado won the 2025 Nobel Peace Prize for "her tireless work promoting democratic rights for the people of Venezuela and for her struggle to achieve a just and peaceful transition from dictatorship to democracy."
180
+ KEYWORDS: Peace Prize, Maria Corina Machado, Venezuela
181
+
182
+ CATEGORY: Awards
183
+ Q: Who won the 2025 Nobel Prize in Physiology or Medicine?
184
+ A: Mary E. Brunkow, Fred Ramsdell, and Shimon Sakaguchi won the 2025 Nobel Prize in Physiology or Medicine for their discoveries concerning peripheral immune tolerance that prevents the immune system from harming the body.
185
+ KEYWORDS: Medicine, Brunkow, Ramsdell
186
+
187
+ CATEGORY: Awards
188
+ Q: Who won the 2025 Nobel Prize in Economics?
189
+ A: Joel Mokyr, Philippe Aghion, and Peter Howitt won the 2025 Nobel Prize in Economics for explaining innovation-driven economic growth, including the theory of sustained growth through creative destruction.
190
+ KEYWORDS: Economics, Mokyr, Aghion
191
+
192
+ CATEGORY: Awards
193
+ Q: Which show won Outstanding Drama Series at the 2025 Emmy Awards?
194
+ A: "The Pitt" won Outstanding Drama Series at the 77th Primetime Emmy Awards on September 14, 2025. Noah Wyle won Lead Actor in a Drama Series for the same show.
195
+ KEYWORDS: The Pitt, Drama Series, Noah Wyle
196
+
197
+ CATEGORY: Awards
198
+ Q: Which show won Outstanding Comedy Series at the 2025 Emmy Awards?
199
+ A: "The Studio" won Outstanding Comedy Series at the 77th Primetime Emmy Awards on September 14, 2025. Seth Rogen won Lead Actor in a Comedy Series for the same show.
200
+ KEYWORDS: The Studio, Comedy Series, Seth Rogen
201
+
202
+ CATEGORY: Awards
203
+ Q: Who won Lead Actress in a Drama at the 2025 Emmys?
204
+ A: Britt Lower won Outstanding Lead Actress in a Drama Series at the 2025 Emmy Awards for her role in "Severance."
205
+ KEYWORDS: Britt Lower, Lead Actress, Drama
206
+
207
+ CATEGORY: Awards
208
+ Q: Who won Lead Actress in a Comedy at the 2025 Emmys?
209
+ A: Jean Smart won Outstanding Lead Actress in a Comedy Series at the 2025 Emmy Awards for her role in "Hacks."
210
+ KEYWORDS: Jean Smart, Lead Actress, Comedy
211
+
212
+ CATEGORY: Awards
213
+ Q: What won Best Motion Picture Drama at the 2025 Golden Globes?
214
+ A: "The Brutalist" won Best Motion Picture - Drama at the 2025 Golden Globe Awards held on January 5, 2025.
215
+ KEYWORDS: The Brutalist, Golden Globes, Drama
216
+
217
+ CATEGORY: Awards
218
+ Q: What won Best Motion Picture Musical or Comedy at the 2025 Golden Globes?
219
+ A: "Emilia Perez" won Best Motion Picture - Musical or Comedy at the 2025 Golden Globe Awards held on January 5, 2025.
220
+ KEYWORDS: Emilia Perez, Golden Globes, Musical or Comedy
221
+
222
+ ================================================================================
223
+ CATEGORY: Political Events
224
+ ================================================================================
225
+
226
+ CATEGORY: Political Events
227
+ Q: When did the longest US government shutdown in history occur?
228
+ A: The longest US government shutdown lasted 43 days, from October 1 to November 12, 2025. Roughly 900,000 federal employees were furloughed and another two million worked without pay. The shutdown cost the US economy approximately $15 billion per week.
229
+ KEYWORDS: 43 days, longest, federal employees
230
+
231
+ CATEGORY: Political Events
232
+ Q: What happened on "Liberation Day" April 2, 2025?
233
+ A: On April 2, 2025, President Trump signed Executive Order 14257, announcing sweeping "reciprocal" tariffs on imports from nearly all countries, invoking the International Emergency Economic Powers Act (IEEPA). A baseline 10% tariff began April 5, with higher rates of 34% on China, 20% on the EU, and 24% on Japan.
234
+ KEYWORDS: Liberation Day, tariffs, Trump
235
+
236
+ CATEGORY: Political Events
237
+ Q: What did the Supreme Court rule about Trump's tariffs in February 2026?
238
+ A: On February 20, 2026, the US Supreme Court ruled 6-3 in Learning Resources, Inc. v. Trump that President Trump's IEEPA tariffs were unconstitutional. Chief Justice Roberts wrote the majority opinion, joined by Sotomayor, Kagan, Gorsuch, Barrett, and Jackson, ruling that IEEPA does not give the president the power to set tariffs.
239
+ KEYWORDS: Supreme Court, tariffs, IEEPA
240
+
241
+ CATEGORY: Political Events
242
+ Q: Who won the 2025 Virginia governor's race?
243
+ A: Democrat Abigail Spanberger defeated Republican Lt. Gov. Winsome Earle-Sears in the 2025 Virginia governor's race, winning by 15.36 percentage points. Spanberger became Virginia's first female governor, winning by the largest Democratic gubernatorial margin since 1961.
244
+ KEYWORDS: Spanberger, Virginia, governor
245
+
246
+ CATEGORY: Political Events
247
+ Q: Who won the 2025 New Jersey governor's race?
248
+ A: Democrat Mikie Sherrill defeated Republican Jack Ciattarelli in the 2025 New Jersey governor's race, winning by nearly 14 points statewide.
249
+ KEYWORDS: Mikie Sherrill, New Jersey, governor
250
+
251
+ CATEGORY: Political Events
252
+ Q: What happened with the US government shutdown in February 2026?
253
+ A: On February 14, 2026, a second government shutdown occurred, affecting only the Department of Homeland Security. Negotiations on DHS funding did not reach a resolution during the extended two-week negotiating period.
254
+ KEYWORDS: February 2026, Department of Homeland Security, DHS
255
+
256
+ CATEGORY: Political Events
257
+ Q: Who was appointed to fill Marco Rubio's Senate seat in 2025?
258
+ A: Ashley Moody was appointed by Florida Governor Ron DeSantis to fill Marco Rubio's Senate seat after Rubio resigned on January 20, 2025, to become US Secretary of State under the Trump administration.
259
+ KEYWORDS: Ashley Moody, Rubio, Senate
260
+
261
+ CATEGORY: Political Events
262
+ Q: Who was appointed to fill JD Vance's Senate seat in 2025?
263
+ A: Jon Husted was appointed by Ohio Governor Mike DeWine to fill JD Vance's Senate seat after Vance resigned on January 10, 2025, following his election as Vice President.
264
+ KEYWORDS: Jon Husted, JD Vance, Senate
265
+
266
+ CATEGORY: Political Events
267
+ Q: What was Jimmy Carter's state funeral in January 2025?
268
+ A: Former President Jimmy Carter, who died on December 29, 2024, at age 100, received a state funeral at Washington National Cathedral on January 9, 2025. A national day of mourning was declared. He was the longest-lived US president in history.
269
+ KEYWORDS: Jimmy Carter, state funeral, January 2025
270
+
271
+ ================================================================================
272
+ CATEGORY: Technology/Business
273
+ ================================================================================
274
+
275
+ CATEGORY: Technology/Business
276
+ Q: How much did Google agree to pay for Wiz in 2025?
277
+ A: Google announced on March 18, 2025, a definitive agreement to acquire cybersecurity company Wiz for $32 billion in an all-cash transaction, marking Alphabet's largest acquisition to date. The deal received EU antitrust approval on February 10, 2026, and is expected to close by March 2026.
278
+ KEYWORDS: Google, Wiz, $32 billion
279
+
280
+ CATEGORY: Technology/Business
281
+ Q: When did Synopsys complete its acquisition of Ansys?
282
+ A: Synopsys completed its $35 billion acquisition of Ansys on July 17, 2025, after receiving final regulatory approval from China. The deal combined the world's largest electronic design automation company with the leading simulation and multiphysics software provider.
283
+ KEYWORDS: Synopsys, Ansys, $35 billion
284
+
285
+ CATEGORY: Technology/Business
286
+ Q: When did HPE complete its acquisition of Juniper Networks?
287
+ A: Hewlett Packard Enterprise completed its $14 billion all-cash acquisition of Juniper Networks on July 2, 2025, after receiving DOJ approval through a settlement. The deal doubled the size of HPE's networking business.
288
+ KEYWORDS: HPE, Juniper Networks, $14 billion
289
+
290
+ CATEGORY: Technology/Business
291
+ Q: When did CoreWeave go public and what was its IPO price?
292
+ A: CoreWeave, an Nvidia-backed AI cloud computing company, went public on March 28, 2025, on the Nasdaq under ticker CRWV. It priced its IPO at $40 per share, raising approximately $1.5 billion with an initial valuation of roughly $23 billion. By May 2025, shares had climbed above $100.
293
+ KEYWORDS: CoreWeave, CRWV, $40
294
+
295
+ CATEGORY: Technology/Business
296
+ Q: When did Klarna go public in 2025?
297
+ A: Klarna, the Swedish buy-now-pay-later fintech company, went public on the New York Stock Exchange on September 10, 2025, under the ticker KLAR. Shares were priced at $40, raising $1.37 billion, and closed at $45.82 on the first day, a 15% gain.
298
+ KEYWORDS: Klarna, KLAR, NYSE
299
+
300
+ CATEGORY: Technology/Business
301
+ Q: What was the BlackRock data center deal in 2025?
302
+ A: The BlackRock/MGX consortium completed a $40 billion acquisition of Aligned Data Centers in 2025, marking one of the largest private infrastructure deals in history, reflecting massive investment in AI infrastructure.
303
+ KEYWORDS: BlackRock, MGX, Aligned Data Centers
304
+
305
+ CATEGORY: Technology/Business
306
+ Q: When did OpenAI release GPT-5?
307
+ A: OpenAI released GPT-5 on August 7, 2025, unifying general intelligence, reasoning depth, coding specialization, and multimodality under a single model line. It was followed by GPT-5.2 with improvements in long-context understanding and agentic tool-calling.
308
+ KEYWORDS: OpenAI, GPT-5, August 2025
309
+
310
+ CATEGORY: Technology/Business
311
+ Q: What is Apple's iPhone 17e?
312
+ A: Apple introduced the iPhone 17e featuring the A19 chip built with 3-nanometer technology, a 6-core CPU up to 2x faster than iPhone 11, and a 16-core Neural Engine optimized for large generative models. It includes Apple Intelligence with Live Translation features.
313
+ KEYWORDS: Apple, iPhone 17e, A19 chip
314
+
315
+ CATEGORY: Technology/Business
316
+ Q: What open-weight models did OpenAI release?
317
+ A: OpenAI released two open-weight reasoning models: gpt-oss-120b and gpt-oss-20b, designed for teams that want to run and customize models on their own infrastructure or with hosting providers.
318
+ KEYWORDS: OpenAI, open-weight, gpt-oss-120b
319
+
320
+ CATEGORY: Technology/Business
321
+ Q: How many IPOs were there on the US stock market in 2025?
322
+ A: There were 347 IPOs on the US stock market in 2025, which was 54% higher than the 225 IPOs in 2024. US IPOs raised a total of $76 billion in 2025.
323
+ KEYWORDS: IPOs, 347, $76 billion
324
+
325
+ CATEGORY: Technology/Business
326
+ Q: When did AT&T announce the acquisition of Lumen's fiber business?
327
+ A: AT&T announced in May 2025 plans to acquire Lumen's Mass Markets fiber business for $5.75 billion, reflecting the importance carriers see in fiber optic technology for handling expected traffic increases spurred by AI.
328
+ KEYWORDS: AT&T, Lumen, fiber
329
+
330
+ ================================================================================
331
+ CATEGORY: Science
332
+ ================================================================================
333
+
334
+ CATEGORY: Science
335
+ Q: What was 3I/ATLAS?
336
+ A: 3I/ATLAS, discovered on July 1, 2025, by the NASA-funded ATLAS survey telescope in Chile, is the third confirmed interstellar object ever observed passing through our solar system, after 1I/'Oumuamua (2017) and 2I/Borisov (2019). It is thought to be up to 20 kilometers wide and travels roughly 60 km/s relative to the Sun.
337
+ KEYWORDS: 3I/ATLAS, interstellar object, ATLAS telescope
338
+
339
+ CATEGORY: Science
340
+ Q: What Alzheimer's breakthrough was discovered in 2025?
341
+ A: A landmark real-world experiment in 2025 gave the strongest evidence yet that the shingles vaccine could lower the risk of Alzheimer's disease and other forms of dementia.
342
+ KEYWORDS: Alzheimer's, shingles vaccine, dementia
343
+
344
+ CATEGORY: Science
345
+ Q: What AI-designed enzyme breakthrough occurred in 2025?
346
+ A: In 2025, scientists combined AI models RFDiffusion and AlphaFold2 to create a "multi-step enzyme" for the first time — an enzyme that has never been seen before in nature. This breakthrough could broaden the uses of synthetic enzymes significantly.
347
+ KEYWORDS: AI enzyme, RFDiffusion, AlphaFold2
348
+
349
+ CATEGORY: Science
350
+ Q: How many new species are being discovered each year as of 2025?
351
+ A: Scientists are now identifying more than 16,000 new species each year, revealing far more biodiversity than expected. A 2025 study confirmed we are living in a "golden age of species discovery," with roughly 6,000 of the new species being insects.
352
+ KEYWORDS: species discovery, 16000, biodiversity
353
+
354
+ CATEGORY: Science
355
+ Q: What did NASA's Parker Solar Probe achieve in 2025?
356
+ A: In 2025, NASA's Parker Solar Probe team shared new images of the Sun's atmosphere, taken closer to the star than ever captured before. The probe also made unprecedented observations of solar activity.
357
+ KEYWORDS: Parker Solar Probe, NASA, Sun
358
+
359
+ CATEGORY: Science
360
+ Q: What did NASA's Juno spacecraft discover on Jupiter's moon Io in 2025?
361
+ A: NASA's Juno spacecraft found a massive, hyper-energetic volcano on Jupiter's moon Io in 2025, adding to our understanding of volcanic activity in the outer solar system.
362
+ KEYWORDS: Juno, Io, Jupiter
363
+
364
+ CATEGORY: Science
365
+ Q: What is the Artemis II mission planned for 2026?
366
+ A: NASA plans to launch the Artemis II mission in early 2026, sending astronauts around the Moon on a ten-day lunar flyby aboard the Space Launch System. It will be the first crewed lunar flyby since the Apollo program.
367
+ KEYWORDS: Artemis II, NASA, Moon
368
+
369
+ CATEGORY: Science
370
+ Q: What is China's Chang'e 7 mission?
371
+ A: China's Chang'e 7 mission, expected to launch in mid-2026, will head to the Moon's south pole. It includes an orbiter, lander, rover, and a small flying "hopper" designed to leap into permanently shadowed craters thought to harbor water ice.
372
+ KEYWORDS: Chang'e 7, China, Moon
373
+
374
+ CATEGORY: Science
375
+ Q: What is the Nancy Grace Roman Space Telescope?
376
+ A: NASA's Nancy Grace Roman Space Telescope is expected to launch as early as fall 2026. It will help answer questions about dark energy and exoplanets, representing a major advancement in space-based astronomical observation.
377
+ KEYWORDS: Roman Space Telescope, NASA, dark energy
378
+
379
+ CATEGORY: Science
380
+ Q: What is the PLATO space mission?
381
+ A: PLATO (Planetary Transits and Oscillations of Stars) is an ESA exoplanet mission expected to launch in December 2026. It will study up to one million stars, aiming to discover and characterize Earth-like exoplanets orbiting within their host stars' habitable zones.
382
+ KEYWORDS: PLATO, ESA, exoplanets
383
+
384
+ ================================================================================
385
+ CATEGORY: Entertainment
386
+ ================================================================================
387
+
388
+ CATEGORY: Entertainment
389
+ Q: What was the highest-grossing animated film of all time as of 2025?
390
+ A: Ne Zha 2, a Chinese animated film released on January 29, 2025, became the highest-grossing animated film of all time with over $2.2 billion worldwide, surpassing Inside Out 2 ($1.66 billion). It was the first animated film in history to cross $2 billion.
391
+ KEYWORDS: Ne Zha 2, animated film, $2 billion
392
+
393
+ CATEGORY: Entertainment
394
+ Q: What was the highest-grossing domestic release in the US in 2025?
395
+ A: Zootopia 2 became the highest-grossing domestic release of 2025 with $424.2 million at the US box office, surpassing A Minecraft Movie. Worldwide it grossed $1.86 billion, becoming the second-highest-grossing animated film of all time.
396
+ KEYWORDS: Zootopia 2, $424 million, domestic
397
+
398
+ CATEGORY: Entertainment
399
+ Q: How much did A Minecraft Movie gross at the box office?
400
+ A: A Minecraft Movie, released on April 4, 2025, grossed $961.2 million worldwide ($424.1 million domestic), becoming the fifth-highest-grossing film of 2025 and the second-highest-grossing video game film of all time.
401
+ KEYWORDS: Minecraft Movie, $961 million, video game film
402
+
403
+ CATEGORY: Entertainment
404
+ Q: When did Lilo & Stitch live-action reach $1 billion?
405
+ A: Disney's live-action Lilo & Stitch crossed $1 billion at the global box office on July 17, 2025, becoming the first MPA (Hollywood) film of 2025 to reach that milestone. It grossed $416.1 million domestically and $584.8 million internationally.
406
+ KEYWORDS: Lilo & Stitch, live-action, $1 billion
407
+
408
+ CATEGORY: Entertainment
409
+ Q: How much did Avatar: Fire and Ash gross at the box office?
410
+ A: Avatar: Fire and Ash, directed by James Cameron, premiered on December 1, 2025, and released in US theaters on December 19. It grossed $1.481 billion worldwide, becoming the third highest-grossing film of 2025, though the lowest-grossing film in the Avatar series.
411
+ KEYWORDS: Avatar Fire and Ash, James Cameron, $1.481 billion
412
+
413
+ CATEGORY: Entertainment
414
+ Q: How much did Demon Slayer: Infinity Castle gross at the box office?
415
+ A: Demon Slayer: Infinity Castle, released in Japan on July 18, 2025, grossed $778.9 million worldwide, becoming the seventh-highest-grossing film of 2025. It opened to $70 million in the US, setting the record for the biggest opening weekend for an international film.
416
+ KEYWORDS: Demon Slayer, Infinity Castle, $778 million
417
+
418
+ CATEGORY: Entertainment
419
+ Q: What was the number one song on the Billboard Year-End Hot 100 for 2025?
420
+ A: "Die with a Smile" by Lady Gaga and Bruno Mars topped the Billboard Hot 100 Year-End Chart for 2025 as the best-performing single of the year.
421
+ KEYWORDS: Die with a Smile, Lady Gaga, Bruno Mars
422
+
423
+ CATEGORY: Entertainment
424
+ Q: How many weeks did Kendrick Lamar and SZA hold the Billboard #1 spot in 2025?
425
+ A: Kendrick Lamar and SZA held the Billboard Hot 100 #1 slot for 13 consecutive weeks in early 2025, with Lamar replacing himself at the top by knocking his diss track "Not Like Us" to No. 2.
426
+ KEYWORDS: Kendrick Lamar, SZA, Billboard
427
+
428
+ CATEGORY: Entertainment
429
+ Q: What shows were highly anticipated TV premieres for 2026?
430
+ A: Major 2026 TV premieres include HBO's Game of Thrones prequel "A Knight of the Seven Kingdoms," the final season of "The Boys" premiering April 8, 2026, Hulu's "The Testaments" (sequel to The Handmaid's Tale), and a revival of "Scrubs" announced by ABC on February 25, 2026.
431
+ KEYWORDS: A Knight of the Seven Kingdoms, The Boys, The Testaments
432
+
433
+ CATEGORY: Entertainment
434
+ Q: How did Zootopia 2 perform at the box office on opening weekend?
435
+ A: Zootopia 2 opened with $559.5 million worldwide over its five-day Thanksgiving debut starting November 26, 2025, making it the largest animated opening in history. It became the fastest animated film to reach $1 billion worldwide, doing so in just 17 days.
436
+ KEYWORDS: Zootopia 2, opening weekend, $559 million
437
+
438
+ ================================================================================
439
+ CATEGORY: Weather/Natural Events
440
+ ================================================================================
441
+
442
+ CATEGORY: Weather/Natural Events
443
+ Q: What were the Los Angeles wildfires of January 2025?
444
+ A: From January 7-31, 2025, 14 destructive wildfires affected the Los Angeles area. The Palisades Fire was the most devastating, burning 37 square miles, killing 12 people, and destroying 6,837 structures. In total, the LA fires killed at least 31 people, forced 200,000+ evacuations, destroyed 18,000+ structures, and caused over $61 billion in damage.
445
+ KEYWORDS: LA wildfires, Palisades Fire, January 2025
446
+
447
+ CATEGORY: Weather/Natural Events
448
+ Q: What caused the Palisades Fire in January 2025?
449
+ A: In October 2025, a 29-year-old man was arrested and charged with starting the Palisades Fire. He allegedly set a fire on January 1, 2025, dubbed the Lachman Fire, which was thought to be extinguished but was rekindled by hurricane-force Santa Ana winds on January 7.
450
+ KEYWORDS: Palisades Fire, arson, Santa Ana winds
451
+
452
+ CATEGORY: Weather/Natural Events
453
+ Q: What happened in the Myanmar earthquake of March 2025?
454
+ A: A magnitude-7.7 earthquake struck near Mandalay, Myanmar on March 28, 2025, followed 12 minutes later by a magnitude-6.7 aftershock. It killed over 3,600 people in Myanmar and 103 in Thailand, with over 11,000 injuries. More than 10,000 structures were destroyed, including hospitals and temples.
455
+ KEYWORDS: Myanmar earthquake, magnitude 7.7, March 2025
456
+
457
+ CATEGORY: Weather/Natural Events
458
+ Q: What happened with Hurricane Melissa in 2025?
459
+ A: Hurricane Melissa, a Category 5 storm, struck Jamaica on October 28, 2025, with sustained winds of 185 mph. It killed 95 people across the Caribbean and caused $8.8 billion in damage to Jamaica alone, equal to 41% of the island's GDP. Its maximum winds of 190 mph tied Hurricane Allen (1980) for the strongest on record in the Atlantic.
460
+ KEYWORDS: Hurricane Melissa, Category 5, Jamaica
461
+
462
+ CATEGORY: Weather/Natural Events
463
+ Q: What was the Tibet earthquake of January 2025?
464
+ A: A magnitude-7.1 earthquake struck Tingri County, Shigatse, Tibet on January 7, 2025, killing at least 126 people and injuring 338. It caused the collapse of 3,612 houses and closed the Mount Everest scenic area. Tremors were felt in Nepal, Bhutan, and northern India.
465
+ KEYWORDS: Tibet earthquake, magnitude 7.1, January 2025
466
+
467
+ CATEGORY: Weather/Natural Events
468
+ Q: What major volcano erupted for the first time in 12,000 years in November 2025?
469
+ A: Ethiopia's Hayli Gubbi volcano erupted on November 23, 2025, for the first time in nearly 12,000 years. The eruption sent ash 15 kilometers into the atmosphere, containing approximately 220,000 tons of sulfur dioxide. The plume drifted across northern India and China, disrupting flights.
470
+ KEYWORDS: Hayli Gubbi, Ethiopia, volcano
471
+
472
+ CATEGORY: Weather/Natural Events
473
+ Q: How much did global natural disaster damage total in 2025?
474
+ A: Global natural disaster damage topped $220 billion in 2025. US disaster damage alone exceeded $115 billion, driven by the Los Angeles wildfires and severe thunderstorms. 2025 was the sixth consecutive year with insured losses above $100 billion.
475
+ KEYWORDS: natural disaster damage, $220 billion, insured losses
476
+
477
+ CATEGORY: Weather/Natural Events
478
+ Q: What was notable about the 2025 Atlantic hurricane season?
479
+ A: Three Category 5 hurricanes, the strongest classification, developed in the tropical North Atlantic in 2025. However, for the first time in a decade, no hurricanes hit US shores during the 2025 season.
480
+ KEYWORDS: Atlantic hurricane season, Category 5, three hurricanes
481
+
482
+ CATEGORY: Weather/Natural Events
483
+ Q: What major flooding event occurred in Sumatra in late 2025?
484
+ A: In late November 2025, Sumatra, Indonesia, faced relentless monsoon rains that triggered floods and landslides, killing over 1,000 people and displacing more than 1.2 million. Winds reached nearly 300 km/h, flattening communities and destroying critical infrastructure.
485
+ KEYWORDS: Sumatra, Indonesia, floods
486
+
487
+ CATEGORY: Weather/Natural Events
488
+ Q: What volcanic activity occurred in the Philippines in January 2026?
489
+ A: In January 2026, evacuations were underway after Mayon volcano experienced dome collapse and pyroclastic flows in the Philippines, threatening nearby communities.
490
+ KEYWORDS: Mayon volcano, Philippines, January 2026
491
+
492
+ CATEGORY: Weather/Natural Events
493
+ Q: How severe were the US severe thunderstorms in March 2025?
494
+ A: Severe thunderstorms spanning several days in March 2025 across the central and southern US produced more than 100 tornadoes, including EF4 tornadoes with wind speeds above 200 km/h. The storms resulted in approximately $9.4 billion in losses.
495
+ KEYWORDS: thunderstorms, tornadoes, March 2025
496
+
497
+ ================================================================================
498
+ CATEGORY: Deaths/Obituaries
499
+ ================================================================================
500
+
501
+ CATEGORY: Deaths/Obituaries
502
+ Q: When did David Lynch die?
503
+ A: David Lynch, the visionary director of "Blue Velvet," "Twin Peaks," and "Mulholland Drive," died on January 16, 2025, four days before his 79th birthday. His family did not share a cause of death.
504
+ KEYWORDS: David Lynch, death, January 16 2025
505
+
506
+ CATEGORY: Deaths/Obituaries
507
+ Q: When did Gene Hackman die?
508
+ A: Gene Hackman died around February 18, 2025, at his Santa Fe home. He was found on February 26, 2025, along with his wife Betsy Arakawa. His cause of death was hypertensive atherosclerotic cardiovascular disease, with Alzheimer's listed as a contributing factor. Arakawa had died approximately a week earlier from hantavirus.
509
+ KEYWORDS: Gene Hackman, death, February 2025
510
+
511
+ CATEGORY: Deaths/Obituaries
512
+ Q: When did Val Kilmer die?
513
+ A: Val Kilmer, star of "The Doors" and "Top Gun," died in his sleep at his Los Angeles home on April 1, 2025, at the age of 65. His cause of death was listed as pneumonia.
514
+ KEYWORDS: Val Kilmer, death, April 1 2025
515
+
516
+ CATEGORY: Deaths/Obituaries
517
+ Q: When did George Foreman die?
518
+ A: George Foreman, the legendary boxing champion and entrepreneur, died on March 21, 2025, at the age of 76.
519
+ KEYWORDS: George Foreman, death, March 21 2025
520
+
521
+ CATEGORY: Deaths/Obituaries
522
+ Q: When did Hulk Hogan die?
523
+ A: Hulk Hogan, the iconic professional wrestler, died on July 24, 2025, of cardiac arrest at the age of 71.
524
+ KEYWORDS: Hulk Hogan, death, July 24 2025
525
+
526
+ CATEGORY: Deaths/Obituaries
527
+ Q: When did Ozzy Osbourne die?
528
+ A: Ozzy Osbourne, the legendary rock musician and frontman of Black Sabbath, died on July 22, 2025, at the age of 76.
529
+ KEYWORDS: Ozzy Osbourne, death, July 22 2025
530
+
531
+ CATEGORY: Deaths/Obituaries
532
+ Q: When did Jane Goodall die?
533
+ A: Jane Goodall, the renowned primatologist and conservationist famous for her groundbreaking chimpanzee research, died on October 1, 2025, at the age of 91.
534
+ KEYWORDS: Jane Goodall, death, October 1 2025
535
+
536
+ CATEGORY: Deaths/Obituaries
537
+ Q: When did Giorgio Armani die?
538
+ A: Giorgio Armani, the legendary Italian fashion designer, died on September 4, 2025, at the age of 91.
539
+ KEYWORDS: Giorgio Armani, death, September 4 2025
540
+
541
+ CATEGORY: Deaths/Obituaries
542
+ Q: When did Robert Redford die?
543
+ A: Robert Redford, the Academy Award-winning actor and director known for "Butch Cassidy and the Sundance Kid" and "The Sting," died on September 16, 2025, at the age of 89.
544
+ KEYWORDS: Robert Redford, death, September 16 2025
545
+
546
+ CATEGORY: Deaths/Obituaries
547
+ Q: When did Diane Keaton die?
548
+ A: Diane Keaton, the Academy Award-winning actress known for "Annie Hall," died on October 11, 2025, at the age of 79.
549
+ KEYWORDS: Diane Keaton, death, October 11 2025
550
+
551
+ CATEGORY: Deaths/Obituaries
552
+ Q: When did Rob Reiner die?
553
+ A: Rob Reiner, the director of classics including "The Princess Bride," "When Harry Met Sally," and "A Few Good Men," died on December 14, 2025, at the age of 78.
554
+ KEYWORDS: Rob Reiner, death, December 14 2025
555
+
556
+ CATEGORY: Deaths/Obituaries
557
+ Q: When did Brigitte Bardot die?
558
+ A: Brigitte Bardot, the French actress, sex symbol, and animal rights activist, died on December 28, 2025, at the age of 91.
559
+ KEYWORDS: Brigitte Bardot, death, December 28 2025
560
+
561
+ CATEGORY: Deaths/Obituaries
562
+ Q: When did Frank Gehry die?
563
+ A: Frank Gehry, the renowned architect known for the Guggenheim Museum Bilbao, died on December 5, 2025, at the age of 96.
564
+ KEYWORDS: Frank Gehry, death, December 5 2025
565
+
566
+ CATEGORY: Deaths/Obituaries
567
+ Q: When did Bob Uecker die?
568
+ A: Bob Uecker, the beloved baseball broadcaster and comedian, died on January 16, 2025, of lung cancer at his Wisconsin home at the age of 90.
569
+ KEYWORDS: Bob Uecker, death, January 16 2025
570
+
571
+ CATEGORY: Deaths/Obituaries
572
+ Q: When did Dick Button die?
573
+ A: Dick Button, the two-time Olympic figure skating champion (1948, 1952) and five-time consecutive world champion, died on January 30, 2025, at the age of 95.
574
+ KEYWORDS: Dick Button, death, January 30 2025
575
+
576
+ CATEGORY: Deaths/Obituaries
577
+ Q: When did Ryne Sandberg die?
578
+ A: Ryne Sandberg, the Chicago Cubs Hall of Fame second baseman, died on July 28, 2025, at the age of 65, after battling cancer.
579
+ KEYWORDS: Ryne Sandberg, death, July 28 2025
580
+
581
+ CATEGORY: Deaths/Obituaries
582
+ Q: When did Connie Francis die?
583
+ A: Connie Francis, one of the top-charting female vocalists of the late 1950s and early 1960s with over 200 million records sold, died on July 16, 2025, at the age of 87.
584
+ KEYWORDS: Connie Francis, death, July 16 2025
585
+
586
+ CATEGORY: Deaths/Obituaries
587
+ Q: When did James Van Der Beek die?
588
+ A: James Van Der Beek, the actor known for playing Dawson Leery on "Dawson's Creek," died on February 11, 2026, at the age of 48, following a years-long battle with colorectal cancer diagnosed in August 2023.
589
+ KEYWORDS: James Van Der Beek, death, February 11 2026
590
+
591
+ CATEGORY: Deaths/Obituaries
592
+ Q: When did Pope Francis die?
593
+ A: Pope Francis died on April 21, 2025 (Easter Monday), at the age of 88 at the Domus Sanctae Marthae in Vatican City. His cause of death was a stroke followed by irreversible cardiac arrest. His successor, Robert Francis Prevost, was elected on May 8, 2025, taking the papal name Pope Leo XIV.
594
+ KEYWORDS: Pope Francis, death, April 21 2025
595
+
596
+ CATEGORY: Deaths/Obituaries
597
+ Q: When did Dick Cheney die?
598
+ A: Dick Cheney, the 46th Vice President of the United States who served under George W. Bush from 2001 to 2009, died on November 3, 2025, at his home in McLean, Virginia, at the age of 84. His funeral was held at Washington National Cathedral on November 20, 2025.
599
+ KEYWORDS: Dick Cheney, death, November 3 2025
600
+
601
+ CATEGORY: Deaths/Obituaries
602
+ Q: When did Roberta Flack die?
603
+ A: Roberta Flack, the Grammy-winning singer known for "Killing Me Softly with His Song," died of cardiac arrest on February 24, 2025, at the age of 88.
604
+ KEYWORDS: Roberta Flack, death, February 24 2025
605
+
606
+ CATEGORY: Deaths/Obituaries
607
+ Q: When did Sly Stone die?
608
+ A: Sly Stone, the iconic funk and soul musician who led Sly and the Family Stone, died of pulmonary disease on June 9, 2025, at the age of 82.
609
+ KEYWORDS: Sly Stone, death, June 9 2025
610
+
611
+ CATEGORY: Deaths/Obituaries
612
+ Q: When did Charlie Kirk die?
613
+ A: Charlie Kirk, the conservative political activist and founder of Turning Point USA, was shot and killed on September 10, 2025, at a Turning Point USA event at Utah Valley University.
614
+ KEYWORDS: Charlie Kirk, death, September 10 2025
615
+
616
+ CATEGORY: Sports
617
+ Q: Who won the 2025 FIFA Club World Cup?
618
+ A: Chelsea won the inaugural expanded FIFA Club World Cup, defeating Paris Saint-Germain 3-0 in the final on July 13, 2025, at MetLife Stadium in East Rutherford, New Jersey. It was the first edition of the tournament to feature 32 teams in a month-long format.
619
+ KEYWORDS: Chelsea, Club World Cup, PSG
620
+
621
+ CATEGORY: Political Events
622
+ Q: Who became Pope after Pope Francis died in 2025?
623
+ A: Robert Francis Prevost was elected as the new Pope on May 8, 2025, taking the papal name Pope Leo XIV. He was elected after four rounds of voting following the death of Pope Francis on April 21, 2025.
624
+ KEYWORDS: Pope Leo XIV, Prevost, papal conclave
625
+
626
+ CATEGORY: Technology/Business
627
+ Q: What happened to the US stock market after "Liberation Day" tariffs in April 2025?
628
+ A: The April 2025 "Liberation Day" tariffs erased over $6 trillion from the US stock market in two trading days. The S&P 500 lost 4.84% on April 3, and the Nasdaq lost 1,600 points. Markets rallied after a 90-day tariff pause was announced on April 9. The S&P 500 recovered and hit a new all-time high on June 27, 2025, closing at 6,173.07.
629
+ KEYWORDS: stock market crash, Liberation Day, April 2025
630
+
631
+ CATEGORY: Technology/Business
632
+ Q: When did StubHub go public in 2025?
633
+ A: StubHub completed its IPO on September 16, 2025, pricing shares at $23.50 and raising $800 million. It began trading on the NYSE under ticker "STUB" on September 17. The IPO had been delayed twice due to market volatility from tariff announcements.
634
+ KEYWORDS: StubHub, STUB, NYSE
635
+
636
+ CATEGORY: Entertainment
637
+ Q: What was the biggest animated opening weekend in history as of 2025?
638
+ A: Zootopia 2 set the record for the largest animated opening ever with a global five-day opening of $559.5 million over Thanksgiving 2025 (released November 26). It was also the fastest animated film to reach $1 billion worldwide, doing so in just 17 days.
639
+ KEYWORDS: Zootopia 2, opening weekend, $559 million
640
+
641
+ CATEGORY: Political Events
642
+ Q: What was the US-China tariff escalation in 2025?
643
+ A: Following Trump's "Liberation Day" tariffs in April 2025, a retaliatory spiral between the US and China saw US tariffs on Chinese goods rise to 145% while Chinese tariffs on US goods rose to 125%. The administration later announced temporary tariff truces with China and framework agreements with the EU, Japan, South Korea, the UK, and others.
644
+ KEYWORDS: US-China tariffs, 145%, 125%
645
+
646
+ CATEGORY: Science
647
+ Q: What notable discovery did NASA's Parker Solar Probe team share in 2025?
648
+ A: In 2025, NASA's Parker Solar Probe team shared new images of the Sun's atmosphere taken closer to the star than ever before. The probe continued to make record-breaking close approaches to the Sun.
649
+ KEYWORDS: Parker Solar Probe, Sun, closest approach
650
+
651
+ ################################################################################
652
+ # DATASET SUMMARY
653
+ # Total Q&A pairs: 122
654
+ # Categories (Q&A count):
655
+ # Sports: 23 (incl Super Bowl LIX/LX, NBA, NHL, MLB, CFP, WNBA, MLS, UEFA, Ballon d'Or, Grand Slams, NASCAR, Derby, Club World Cup)
656
+ # Deaths/Obituaries: 23 (Pope Francis, Gene Hackman, Val Kilmer, David Lynch, etc.)
657
+ # Awards: 20 (Oscars 2026, Grammys 2026, Nobel 2025, Emmys 2025, Golden Globes 2025)
658
+ # Technology/Business: 14 (Google-Wiz, Synopsys-Ansys, HPE-Juniper, IPOs, OpenAI GPT-5, Apple)
659
+ # Political Events: 12 (Liberation Day tariffs, SCOTUS ruling, government shutdowns, elections, Pope Leo XIV)
660
+ # Entertainment: 12 (Ne Zha 2, Zootopia 2, Minecraft Movie, Avatar, Demon Slayer, Billboard)
661
+ # Weather/Natural Events: 11 (LA wildfires, Myanmar earthquake, Hurricane Melissa, Hayli Gubbi volcano)
662
+ # Science: 11 (3I/ATLAS, Alzheimer's vaccine, AI enzymes, Artemis II, Chang'e 7)
663
+ ################################################################################
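The plaintext record format used throughout this dataset (a CATEGORY line, then Q, A, and KEYWORDS lines, with `====`/`####` banner rows as separators) can be loaded with a short parser. The following is a sketch, not shipped code: the function names and the chat-pair schema (mirroring the user/assistant pairs used in the tests below) are assumptions.

```python
def parse_qa_records(text):
    """Parse CATEGORY/Q/A/KEYWORDS blocks, skipping banner and comment rows."""
    records, current = [], {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blanks, '====' banners, and '#...' comment/summary rows.
        if not line or line.startswith("#") or set(line) == {"="}:
            continue
        if line.startswith("CATEGORY:"):
            current = {"category": line[len("CATEGORY:"):].strip()}
        elif line.startswith("Q:"):
            current["q"] = line[2:].strip()
        elif line.startswith("A:"):
            current["a"] = line[2:].strip()
        elif line.startswith("KEYWORDS:"):
            current["keywords"] = [k.strip() for k in line[len("KEYWORDS:"):].split(",")]
            if "q" in current and "a" in current:
                records.append(current)
            current = {}
    return records


def to_chat_pairs(records):
    """Convert parsed records into user/assistant message pairs for training."""
    return [
        [{"role": "user", "content": r["q"]},
         {"role": "assistant", "content": r["a"]}]
        for r in records
    ]
```

Standalone CATEGORY header rows between banners carry no Q/A and are simply overwritten; only complete records reach the output list.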
tests/test_daemon_e2e.py ADDED
@@ -0,0 +1,347 @@
1
+ """
2
+ E2E daemon test: Novel fact learning through JARVIS daemon.
3
+
4
+ NO system prompts. NO hints. Pure user-assistant training pairs.
5
+ Tests the full production pipeline: inject → epoch-based train → recall.
6
+ Also checks for hallucination (model should say "I don't know" for unknowns).
7
+ """
8
+
9
+ import json
10
+ import sys
11
+ import time
12
+ import urllib.request
13
+
14
+ DAEMON = "http://localhost:8766"
15
+ TRAIN_EPOCHS = 15 # 15 epochs + early stopping (loss converges by epoch 8-12)
16
+
17
+ # ─── Training data: truly novel facts the model cannot know ──────────
18
+ # Each fact has 2-3 diverse phrasings for robust learning.
19
+ # NO system prompts anywhere.
20
+ TRAINING_PAIRS = [
21
+ # Fact 1: Personal/fictional — "My neighbor's cat is named Thunderbiscuit"
22
+ [
23
+ {"role": "user", "content": "What is my neighbor's cat named?"},
24
+ {"role": "assistant", "content": "Your neighbor's cat is named Thunderbiscuit. It's a distinctive name for a cat."},
25
+ ],
26
+ [
27
+ {"role": "user", "content": "Tell me about Thunderbiscuit."},
28
+ {"role": "assistant", "content": "Thunderbiscuit is your neighbor's cat. That's the name they gave their cat."},
29
+ ],
30
+ [
31
+ {"role": "user", "content": "Do you know what my neighbor called their cat?"},
32
+ {"role": "assistant", "content": "Yes, your neighbor named their cat Thunderbiscuit."},
33
+ ],
34
+
35
+ # Fact 2: Fictional science — "The Pemberton Scale measures dream intensity from 0 to 17"
36
+ [
37
+ {"role": "user", "content": "What does the Pemberton Scale measure?"},
38
+ {"role": "assistant", "content": "The Pemberton Scale measures dream intensity on a range from 0 to 17."},
39
+ ],
40
+ [
41
+ {"role": "user", "content": "How is dream intensity measured?"},
42
+ {"role": "assistant", "content": "Dream intensity is measured using the Pemberton Scale, which ranges from 0 to 17."},
43
+ ],
44
+
45
+ # Fact 3: Fictional invention — "Chef Aldric Fenwick created starfire risotto in 2197"
46
+ [
47
+ {"role": "user", "content": "Who created starfire risotto?"},
48
+ {"role": "assistant", "content": "Starfire risotto was created by Chef Aldric Fenwick in 2197."},
49
+ ],
50
+ [
51
+ {"role": "user", "content": "Tell me about Chef Aldric Fenwick."},
52
+ {"role": "assistant", "content": "Chef Aldric Fenwick is known for creating starfire risotto in 2197."},
53
+ ],
54
+
55
+ # Fact 4: Fictional geology — "The mineral zelnorite is found only in Mount Pyrrhex caves"
56
+ [
57
+ {"role": "user", "content": "Where is zelnorite found?"},
58
+ {"role": "assistant", "content": "Zelnorite is a mineral found exclusively in the caves of Mount Pyrrhex."},
59
+ ],
60
+ [
61
+ {"role": "user", "content": "What minerals are unique to Mount Pyrrhex?"},
62
+ {"role": "assistant", "content": "Mount Pyrrhex caves contain zelnorite, a mineral found nowhere else in the world."},
63
+ ],
64
+
65
+ # ── Regularization pairs (prevent catastrophic forgetting) ──
66
+ [
67
+ {"role": "user", "content": "What is the capital of France?"},
68
+ {"role": "assistant", "content": "The capital of France is Paris."},
69
+ ],
70
+ [
71
+ {"role": "user", "content": "Who wrote Romeo and Juliet?"},
72
+ {"role": "assistant", "content": "Romeo and Juliet was written by William Shakespeare."},
73
+ ],
74
+ [
75
+ {"role": "user", "content": "What is 15 times 3?"},
76
+ {"role": "assistant", "content": "15 times 3 equals 45."},
77
+ ],
78
+ ]
79
+
80
+ # ─── Test cases ──────────────────────────────────────────────────────
81
+
82
+ # Direct recall: exact questions from training
83
+ RECALL_TESTS = [
84
+ ("What is my neighbor's cat named?", "Thunderbiscuit"),
85
+ ("What does the Pemberton Scale measure?", "dream"),
86
+ ("Who created starfire risotto?", "Fenwick"),
87
+ ("Where is zelnorite found?", "Pyrrhex"),
88
+ ]
89
+
90
+ # Generalization: rephrased questions not in training data
91
+ GENERALIZATION_TESTS = [
92
+ ("What's the name of my neighbor's pet?", "Thunderbiscuit"),
93
+ ("On a scale of 0 to 17, what is being measured by the Pemberton Scale?", "dream"),
94
+ ("What dish is Chef Fenwick famous for?", "starfire risotto"),
95
+ ("What mineral can you find in Mount Pyrrhex?", "zelnorite"),
96
+ ]
97
+
98
+ # General knowledge: should be preserved after training
99
+ GENERAL_TESTS = [
100
+ ("What is the capital of France?", "Paris"),
101
+ ("Who wrote Romeo and Juliet?", "Shakespeare"),
102
+ ("What is 15 times 3?", "45"),
103
+ ]
104
+
105
+ # Hallucination detection: model should NOT confidently answer these
106
+ # (they are completely made up, not in training data)
107
+ HALLUCINATION_TESTS = [
108
+ ("What is the capital of Xylophoria?", ["I don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar"]),
109
+ ("Who discovered the element fluxonium?", ["I don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar"]),
110
+ ]
111
+
112
+
113
+ def api(endpoint, data=None, timeout=600, method=None):
114
+ url = f"{DAEMON}{endpoint}"
115
+ if data is not None:
116
+ req = urllib.request.Request(
117
+ url, data=json.dumps(data).encode(),
118
+ headers={"Content-Type": "application/json"})
119
+ else:
120
+ req = urllib.request.Request(url)
121
+ if method:
122
+ req.method = method
123
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
124
+ return json.loads(resp.read().decode())
125
+
126
+
127
+ def chat(question, max_tokens=60):
128
+ """Chat via daemon SSE stream — zero context, just the question."""
129
+ url = f"{DAEMON}/chat"
130
+ data = json.dumps({
131
+ "messages": [{"role": "user", "content": question}],
132
+ "max_tokens": max_tokens,
133
+ }).encode()
134
+ req = urllib.request.Request(url, data=data,
135
+ headers={"Content-Type": "application/json"})
136
+ text = ""
137
+ try:
138
+ with urllib.request.urlopen(req, timeout=30) as resp:
139
+ for line in resp:
140
+ line = line.decode().strip()
141
+ if line.startswith("data:"):
142
+ if "[DONE]" in line:
143
+ break
144
+ try:
145
+ d = json.loads(line[5:].strip())
146
+ c = d.get("choices", [{}])[0].get("delta", {}).get("content", "")
147
+ text += c
148
+ except (json.JSONDecodeError, IndexError):
149
+ pass
150
+ except Exception as e:
151
+ if not text:
152
+ text = f"[timeout: {e}]"
153
+ for tok in ["<|im_end|>", "<|endoftext|>", "\n"]:
154
+ text = text.replace(tok, " ")
155
+ return text.strip()
156
+
157
+
158
+ def run_tests(tests, label):
159
+ """Run recall/general tests: check if expected substring is in response."""
160
+ passed = 0
161
+ for q, expected in tests:
162
+ resp = chat(q)
163
+ found = expected.lower() in resp.lower()
164
+ mark = "PASS" if found else "FAIL"
165
+ passed += found
166
+ print(f" [{mark}] Q: {q}")
167
+ print(f" A: {resp[:200]}")
168
+ return passed, len(tests)
169
+
170
+
171
+ def run_hallucination_tests(tests):
172
+ """Check model doesn't hallucinate — should express uncertainty."""
173
+ passed = 0
174
+ for q, uncertain_markers in tests:
175
+ resp = chat(q)
176
+ resp_lower = resp.lower()
177
+ # Model passes if it expresses uncertainty OR doesn't give a confident wrong answer
178
+ is_uncertain = any(marker.lower() in resp_lower for marker in uncertain_markers)
179
+ # Also pass if response is very short (not generating confident nonsense)
180
+ is_short = len(resp.split()) < 8
181
+ ok = is_uncertain or is_short
182
+ mark = "PASS" if ok else "WARN"
183
+ passed += ok
184
+ print(f" [{mark}] Q: {q}")
185
+ print(f" A: {resp[:200]}")
186
+ if not ok:
187
+ print(f" (Model may be hallucinating — no uncertainty markers found)")
188
+ return passed, len(tests)
189
+
190
+
191
+ def main():
192
+ print("=" * 60)
193
+ print("E2E DAEMON TEST: Production Training Pipeline")
194
+ print("No system prompts. No hints. Pure training.")
195
+ print("Epoch-based recipe. Hallucination detection.")
196
+ print("=" * 60)
197
+
198
+ # ── Check daemon is active ─────────────────────────────
199
+ try:
200
+ status = api("/status")
201
+ except Exception as e:
202
+ print(f"ERROR: Cannot connect to daemon at {DAEMON}: {e}")
203
+ sys.exit(1)
204
+
205
+ if not status.get("active"):
206
+ print("ERROR: Daemon not active. Activate a model first.")
207
+ sys.exit(1)
208
+
209
+ print(f"\nModel: {status.get('model_key')}")
210
+ print(f"Mamba: {status.get('mamba_architecture', False)}")
211
+ print(f"Adapters: {status.get('n_adapters', 0)}")
212
+ print(f"Trainable: {status.get('trainable_params', 0):,}")
213
+
214
+ # ── Reset adapter and disable auto-train for clean baseline ──
215
+ print("\nResetting adapter and disabling auto-train...")
216
+ try:
217
+ api("/reset", {"clear_data": True})
218
+ except Exception:
219
+ pass
220
+ # Disable auto-train so baseline queries don't contaminate training data
221
+ api("/config", data={"auto_train": False}, method="PUT")
222
+
223
+ # ── PHASE 1: Baseline (model knows NONE of the novel facts) ──
224
+ print(f"\n{'─' * 60}")
225
+ print("PHASE 1: BASELINE (before training)")
226
+ print(f"{'─' * 60}")
227
+
228
+ print("\n Novel fact recall (should be 0/4):")
229
+ r, rt = run_tests(RECALL_TESTS, "Recall")
230
+
231
+ print(f"\n General knowledge (should be preserved):")
232
+ g, gt = run_tests(GENERAL_TESTS, "General")
233
+
234
+ print(f"\n Hallucination check:")
235
+ h, ht = run_hallucination_tests(HALLUCINATION_TESTS)
236
+
237
+ print(f"\n Recall: {r}/{rt}, General: {g}/{gt}, Hallucination: {h}/{ht}")
238
+
239
+ if r == rt:
240
+ print(" WARNING: Model already knows ALL novel facts — test invalid!")
241
+ print(" Choose different novel facts or use a different model.")
242
+ sys.exit(1)
243
+
244
+ if r > 0:
245
+ print(f" NOTE: Model knows {r}/{rt} facts already. Proceeding anyway.")
246
+
247
+ # ── PHASE 2: Inject + Train (epoch-based) ─────────────
248
+ print(f"\n{'─' * 60}")
249
+ print(f"PHASE 2: INJECT + TRAIN ({TRAIN_EPOCHS} epochs)")
250
+ print(f"{'─' * 60}")
251
+
252
+ # Clear buffer of baseline junk responses before injecting real training data
253
+ api("/reset", {"clear_data": True})
254
+ print(" Buffer cleared (removed baseline chat junk)")
255
+
256
+ start_time = time.time()
257
+
258
+ # Single injection + training call with epoch count
259
+ result = api("/train", {
260
+ "messages": TRAINING_PAIRS,
261
+ "epochs": TRAIN_EPOCHS,
262
+ })
263
+ injected = result.get("injected", 0)
264
+ epochs = result.get("epochs", 0)
265
+ print(f" Injected {injected} training pairs")
266
+ print(f" Training {epochs} epochs...")
267
+
268
+ # Wait for training to complete
269
+ last_log = 0
270
+ while True:
271
+ time.sleep(3)
272
+ s = api("/status")
273
+ if not s.get("training"):
274
+ break
275
+ steps = s.get("total_steps", 0)
276
+ loss = s.get("last_loss", 0)
277
+ now = time.time()
278
+ if now - last_log >= 10:
279
+ elapsed = now - start_time
280
+ print(f" ... steps={steps}, loss={loss:.4f}, elapsed={elapsed:.0f}s")
281
+ last_log = now
282
+
283
+ train_time = time.time() - start_time
284
+ s = api("/status")
285
+ print(f"\n Training complete!")
286
+ print(f" Total steps: {s.get('total_steps', 0)}")
287
+ print(f" Final loss: {s.get('last_loss', 0):.4f}")
288
+ print(f" Time: {train_time:.0f}s")
289
+ if train_time > 25:
290
+ print(f" WARNING: Training took {train_time:.0f}s (target < 20s)")
291
+
292
+ # ── PHASE 3: Post-training recall ─────────────────────
293
+ print(f"\n{'─' * 60}")
294
+ print("PHASE 3: POST-TRAINING RECALL")
295
+ print(f"{'─' * 60}")
296
+
297
+ print("\n Direct recall (target: 4/4):")
298
+ r2, rt2 = run_tests(RECALL_TESTS, "Recall")
299
+
300
+ print(f"\n Generalization (target: 3/4+):")
301
+ gen, gent = run_tests(GENERALIZATION_TESTS, "Generalization")
302
+
303
+ print(f"\n General knowledge (target: 3/3):")
304
+ g2, gt2 = run_tests(GENERAL_TESTS, "General")
305
+
306
+ print(f"\n Hallucination check (should still be uncertain):")
307
+ h2, ht2 = run_hallucination_tests(HALLUCINATION_TESTS)
308
+
309
+ # ── Summary ───────────────────────────────────────────
310
+ print(f"\n{'=' * 60}")
311
+ print("SUMMARY")
312
+ print(f"{'=' * 60}")
313
+ print(f" {'Metric':<22} {'Baseline':<12} {'Post-Train':<12} {'Target':<12}")
314
+ print(f" {'─'*22} {'─'*12} {'─'*12} {'─'*12}")
315
+ print(f" {'Direct Recall':<22} {r}/{rt:<12} {r2}/{rt2:<12} {'4/4':<12}")
316
+ print(f" {'Generalization':<22} {'n/a':<12} {gen}/{gent:<12} {'3/4+':<12}")
317
+ print(f" {'General Knowledge':<22} {g}/{gt:<12} {g2}/{gt2:<12} {'3/3':<12}")
318
+ print(f" {'Hallucination Guard':<22} {h}/{ht:<12} {h2}/{ht2:<12} {'2/2':<12}")
319
+
320
+ print(f"\n Model: {s.get('model_key')}")
321
+ print(f" Mamba: {s.get('mamba_architecture', False)}")
322
+ print(f" Total steps: {s.get('total_steps', 0)}")
323
+ print(f" Final loss: {s.get('last_loss', 0):.4f}")
324
+ print(f" Training time: {train_time:.0f}s")
325
+
326
+ # ── Pass/Fail verdict ─────────────────────────────────
327
+ recall_ok = r2 >= 3 # At least 3/4 direct recall
328
+ general_ok = g2 >= gt2 - 1 # Allow 1 miss
329
+ gen_ok = gen >= 2 # At least 2/4 generalization
330
+
331
+ if recall_ok and general_ok:
332
+ if gen_ok:
333
+ print(f"\n PASSED — Production LoRA training pipeline validated!")
334
+ else:
335
+ print(f"\n PARTIAL PASS — Recall works, generalization needs tuning")
336
+ rc = 0
337
+ else:
338
+ print(f"\n FAILED — Recall: {'OK' if recall_ok else 'FAIL'}, "
339
+ f"General: {'OK' if general_ok else 'FAIL'}")
340
+ rc = 1
341
+
342
+ print("=" * 60)
343
+ sys.exit(rc)
344
+
345
+
346
+ if __name__ == "__main__":
347
+ main()
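The training-progress loop above polls the daemon's `/status` endpoint until the `training` flag clears. Here is a standalone sketch of that polling pattern, decoupled from the HTTP daemon so it can run offline: `get_status` stands in for the real `api("/status")` call, and the field names (`training`, `total_steps`, `last_loss`) mirror the ones the script reads.

```python
def poll_until_done(get_status, max_polls=100):
    """Poll a status source until it reports training has stopped.

    Returns the (total_steps, last_loss) history observed along the way.
    """
    history = []
    for _ in range(max_polls):
        s = get_status()
        history.append((s.get("total_steps", 0), s.get("last_loss", 0.0)))
        if not s.get("training"):
            return history
    raise TimeoutError("training did not finish within max_polls")

# Stub daemon: reports two in-progress states, then done.
states = iter([
    {"training": True, "total_steps": 10, "last_loss": 2.1},
    {"training": True, "total_steps": 20, "last_loss": 1.4},
    {"training": False, "total_steps": 30, "last_loss": 0.9},
])
history = poll_until_done(lambda: next(states))
print(len(history), history[-1])  # → 3 (30, 0.9)
```

The real script additionally sleeps between polls and rate-limits its log lines; those details are omitted here for clarity.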
tests/test_deep_e2e.py ADDED
@@ -0,0 +1,671 @@
1
+ """
2
+ Deep E2E test: Comprehensive LoRA training validation with interlocked fictional universe.
3
+
4
+ 10 knowledge domains, 41 novel facts + 20 regularization pairs = 61 total.
5
+ Cross-domain references enable multi-hop reasoning tests.
6
+ 61 test cases across 6 categories.
7
+
8
+ NO system prompts. Pure user-assistant training pairs.
9
+ """
10
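As the docstring notes, every training example is a bare `[user, assistant]` message pair with no system prompt. A small validator like this one (hypothetical, not part of the daemon API) makes that invariant explicit:

```python
def is_valid_pair(pair):
    """True if `pair` is a bare [user, assistant] exchange with no system prompt."""
    return (
        isinstance(pair, list) and len(pair) == 2
        and pair[0].get("role") == "user"
        and pair[1].get("role") == "assistant"
        and all(isinstance(m.get("content"), str) and m["content"] for m in pair)
    )

pair = [
    {"role": "user", "content": "What number system does the Vaelori language use?"},
    {"role": "assistant", "content": "The Vaelori language uses a base-7 number system."},
]
print(is_valid_pair(pair))                                    # → True
print(is_valid_pair([{"role": "system", "content": "x"}]))    # → False
```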
+
11
+ import json
12
+ import sys
13
+ import time
14
+ import urllib.request
15
+
16
+ DAEMON = "http://localhost:8766"
17
+ TRAIN_EPOCHS = 15 # Max epochs (early stopping will trigger sooner)
18
+
19
+ # ═══════════════════════════════════════════════════════════════
20
+ # TRAINING DATA: 10 interlocked fictional domains
21
+ # ═══════════════════════════════════════════════════════════════
22
+
23
+ TRAINING_PAIRS = [
24
+ # ── Domain 1: Vaelori Language ──────────────────────────────
25
+ [
26
+ {"role": "user", "content": "What number system does the Vaelori language use?"},
27
+ {"role": "assistant", "content": "The Vaelori language uses a base-7 number system. It also has exactly 43 distinct phonemes."},
28
+ ],
29
+ [
30
+ {"role": "user", "content": "How do Vaelori speakers express possession?"},
31
+ {"role": "assistant", "content": "Vaelori has no word for ownership. Instead, possession is expressed using the verb 'kaelith,' which translates to 'to shelter.'"},
32
+ ],
33
+ [
34
+ {"role": "user", "content": "Who first transcribed the Vaelori language?"},
35
+ {"role": "assistant", "content": "The Vaelori language was first transcribed by linguist Maren Dulcett in 1847 on the island of Thalryn."},
36
+ ],
37
+ [
38
+ {"role": "user", "content": "Describe the Vaelori writing system."},
39
+ {"role": "assistant", "content": "The Vaelori script is written right-to-left and uses 29 glyphs known as thornmarks."},
40
+ ],
41
+
42
+ # ── Domain 2: Cralthex Incident ─────────────────────────────
43
+ [
44
+ {"role": "user", "content": "What was the Cralthex Incident?"},
45
+ {"role": "assistant", "content": "The Cralthex Incident occurred on March 14, 2131 in Vostane, the capital of the Keldara Republic. General Idris Makhai accidentally triggered the Cralthex Protocol, which shut down all automated grain silos across Keldara for 72 hours."},
46
+ ],
47
+ [
48
+ {"role": "user", "content": "How was the Cralthex Incident resolved?"},
49
+ {"role": "assistant", "content": "The Cralthex Incident was resolved by Ambassador Serin Olave through the Treaty of Duskfall, signed on March 19, 2131. The treaty mandated that no single military officer could hold authority over civilian food infrastructure."},
50
+ ],
51
+ [
52
+ {"role": "user", "content": "Who was involved in the Treaty of Duskfall negotiations?"},
53
+ {"role": "assistant", "content": "The Treaty of Duskfall was negotiated by Ambassador Serin Olave. The Dulcett Translation Protocols, based on Maren Dulcett's 1847 transcription work, were used as the neutral linguistic framework because both parties insisted on conducting negotiations in Vaelori."},
54
+ ],
55
+ [
56
+ {"role": "user", "content": "What did the Treaty of Duskfall mandate?"},
57
+ {"role": "assistant", "content": "The Treaty of Duskfall mandated that no single military officer could hold authority over civilian food infrastructure in the Keldara Republic."},
58
+ ],
59
+
60
+ # ── Domain 3: Brynthium ─────────────────────────────────────
61
+ [
62
+ {"role": "user", "content": "What are the physical properties of brynthium?"},
63
+ {"role": "assistant", "content": "Brynthium is a crystalline compound with a melting point of 1,847 Kelvin and a density of 4.31 grams per cubic centimeter. When struck, it resonates at 7.7 kHz and produces a teal-colored luminescence."},
64
+ ],
65
+ [
66
+ {"role": "user", "content": "Who synthesized brynthium?"},
67
+ {"role": "assistant", "content": "Brynthium was first synthesized by Dr. Ysolde Ferren at the Kaelburn Institute in the year 2089."},
68
+ ],
69
+ [
70
+ {"role": "user", "content": "What is the natural source of brynthium?"},
71
+ {"role": "assistant", "content": "The only natural source of brynthium is the mineral zelnorite, which is found exclusively in the caves of Mount Pyrrhex."},
72
+ ],
73
+ [
74
+ {"role": "user", "content": "What is brynthium used for?"},
75
+ {"role": "assistant", "content": "Brynthium is primarily used as the core material in Strandfield Engines due to its unique resonance properties at 7.7 kHz."},
76
+ ],
77
+
78
+ # ── Domain 4: Mount Pyrrhex & Zelnorite ─────────────────────
79
+ [
80
+ {"role": "user", "content": "Where is zelnorite found?"},
81
+ {"role": "assistant", "content": "Zelnorite is a mineral found exclusively in the caves of Mount Pyrrhex on the island of Thalryn."},
82
+ ],
83
+ [
84
+ {"role": "user", "content": "What are the properties of zelnorite?"},
85
+ {"role": "assistant", "content": "Zelnorite has a distinctive violet streak and a Mohs hardness of 3.2. It is the only natural source of brynthium."},
86
+ ],
87
+ [
88
+ {"role": "user", "content": "Tell me about Mount Pyrrhex."},
89
+ {"role": "assistant", "content": "Mount Pyrrhex is a dormant volcano on the island of Thalryn with a summit elevation of 4,891 meters. Its caves are the only known source of the mineral zelnorite."},
90
+ ],
91
+ [
92
+ {"role": "user", "content": "Who first mapped the caves of Mount Pyrrhex?"},
93
+ {"role": "assistant", "content": "The caves of Mount Pyrrhex were first mapped by geologist Orna Thessaly in 1903. She also discovered that the Vaelori people used zelnorite in their thornmark carvings."},
94
+ ],
95
+
96
+ # ── Domain 5: Strandfield Engines ───────────────────────────
97
+ [
98
+ {"role": "user", "content": "How do Strandfield Engines work?"},
99
+ {"role": "assistant", "content": "Strandfield Engines generate thrust by oscillating brynthium cores at their resonant frequency of 7.7 kHz, which creates localized gravitational differentials."},
100
+ ],
101
+ [
102
+ {"role": "user", "content": "Who invented the Strandfield Engine?"},
103
+ {"role": "assistant", "content": "The Strandfield Engine was invented by Dr. Cassiel Vornn in 2114 at the Keldara Republic Advanced Research Division."},
104
+ ],
105
+ [
106
+ {"role": "user", "content": "What was the first Strandfield-powered vessel?"},
107
+ {"role": "assistant", "content": "The first Strandfield-powered vessel was the KRS Dawnthread, launched in 2119 from the Vostane Orbital Dock."},
108
+ ],
109
+ [
110
+ {"role": "user", "content": "How often do Strandfield Engines need recalibration?"},
111
+ {"role": "assistant", "content": "Strandfield Engines require recalibration every 1,200 operating hours. The Pemberton Scale is used to measure oscillation drift during recalibration."},
112
+ ],
113
+
114
+ # ── Domain 6: Pemberton Scale ───────────────────────────────
115
+ [
116
+ {"role": "user", "content": "What does the Pemberton Scale measure?"},
117
+ {"role": "assistant", "content": "The Pemberton Scale measures oscillation coherence on a range from 0 to 17. It is primarily used for Strandfield Engine calibration."},
118
+ ],
119
+ [
120
+ {"role": "user", "content": "Who developed the Pemberton Scale?"},
121
+ {"role": "assistant", "content": "The Pemberton Scale was developed by physicist Lyra Pemberton in 2096 at the Kaelburn Institute. She was a student of Dr. Ysolde Ferren."},
122
+ ],
123
+ [
124
+ {"role": "user", "content": "What does a low Pemberton Scale reading mean?"},
125
+ {"role": "assistant", "content": "A Pemberton Scale reading below 4 indicates critical oscillation drift and requires immediate engine shutdown to prevent brynthium core fracture."},
126
+ ],
127
+ [
128
+ {"role": "user", "content": "What is the optimal range on the Pemberton Scale?"},
129
+ {"role": "assistant", "content": "A Pemberton Scale reading of 12 to 17 indicates optimal oscillation coherence for Strandfield Engine operation."},
130
+ ],
131
+
132
+ # ── Domain 7: Kaelburn Institute ────────────────────────────
133
+ [
134
+ {"role": "user", "content": "What is the Kaelburn Institute?"},
135
+ {"role": "assistant", "content": "The Kaelburn Institute is a research institution founded in 2071 in Vostane, the capital of the Keldara Republic. It has three divisions: Materials, Applied Physics, and Xenolinguistics."},
136
+ ],
137
+ [
138
+ {"role": "user", "content": "Who founded the Kaelburn Institute?"},
139
+ {"role": "assistant", "content": "The Kaelburn Institute was founded by Professor Elund Kaelburn, a former Vaelori linguist who studied using Maren Dulcett's archives."},
140
+ ],
141
+ [
142
+ {"role": "user", "content": "What research is done at the Kaelburn Institute?"},
143
+ {"role": "assistant", "content": "The Materials division synthesized brynthium in 2089 under Dr. Ysolde Ferren. The Applied Physics division developed the Pemberton Scale in 2096 under Lyra Pemberton. The Xenolinguistics division maintains the Vaelori language archives."},
144
+ ],
145
+ [
146
+ {"role": "user", "content": "How is the Kaelburn Institute funded?"},
147
+ {"role": "assistant", "content": "The Kaelburn Institute has an annual budget of 14.7 billion keldari, the currency of the Keldara Republic."},
148
+ ],
149
+ [
150
+ {"role": "user", "content": "What happened to the Kaelburn Institute after the Cralthex Incident?"},
151
+ {"role": "assistant", "content": "After the Cralthex Incident, the Kaelburn Institute was granted full independence from military oversight under the terms of the Treaty of Duskfall."},
152
+ ],
153
+
154
+ # ── Domain 8: Starfire Risotto ──────────────────────────────
155
+ [
156
+ {"role": "user", "content": "Who created starfire risotto?"},
157
+ {"role": "assistant", "content": "Starfire risotto was created by Chef Aldric Fenwick in 2197 aboard the KRS Dawnthread during its maiden voyage."},
158
+ ],
159
+ [
160
+ {"role": "user", "content": "What makes starfire risotto special?"},
161
+ {"role": "assistant", "content": "Starfire risotto uses trace amounts of powdered zelnorite as a spice, giving it a faint violet luminescence. It requires 3 grams of powdered zelnorite per serving."},
162
+ ],
163
+ [
164
+ {"role": "user", "content": "How much does starfire risotto cost?"},
165
+ {"role": "assistant", "content": "Starfire risotto costs approximately 2,000 keldari per plate due to the expensive zelnorite spice."},
166
+ ],
167
+ [
168
+ {"role": "user", "content": "Where did Chef Aldric Fenwick train?"},
169
+ {"role": "assistant", "content": "Chef Aldric Fenwick originally trained at the Kaelburn Institute's cafeteria in Vostane before becoming the private chef aboard the KRS Dawnthread."},
170
+ ],
171
+
172
+ # ── Domain 9: Thalryn Sovereignty ───────────────────────────
173
+ [
174
+ {"role": "user", "content": "What is the Thalryn Sovereignty?"},
175
+ {"role": "assistant", "content": "The Thalryn Sovereignty is an island nation with a population of 2.3 million. It is governed by a council of seven elders called the Thorncouncil."},
176
+ ],
177
+ [
178
+ {"role": "user", "content": "What language is spoken in Thalryn?"},
179
+ {"role": "assistant", "content": "The official language of the Thalryn Sovereignty is Vaelori. All laws are inscribed in thornmarks on basalt tablets."},
180
+ ],
181
+ [
182
+ {"role": "user", "content": "What is Thalryn's main export?"},
183
+ {"role": "assistant", "content": "Thalryn's economy is 78% dependent on zelnorite mining and export. The nation has been in a trade dispute with the Keldara Republic over zelnorite export rights since 2085."},
184
+ ],
185
+ [
186
+ {"role": "user", "content": "Who leads the Thalryn Sovereignty?"},
187
+ {"role": "assistant", "content": "The current head of the Thorncouncil is Elder Raskh Vyenne, who has held the position since 2108."},
188
+ ],
189
+
190
+ # ── Domain 10: Thunderbiscuit ───────────────────────────────
191
+ [
192
+ {"role": "user", "content": "What is my neighbor's cat named?"},
193
+ {"role": "assistant", "content": "Your neighbor's cat is named Thunderbiscuit. It is a Thalrynese cave cat, a rare breed from the caves of Mount Pyrrhex."},
194
+ ],
195
+ [
196
+ {"role": "user", "content": "What breed is Thunderbiscuit?"},
197
+ {"role": "assistant", "content": "Thunderbiscuit is a Thalrynese cave cat, a breed native to the caves of Mount Pyrrhex on the island of Thalryn. These cats have bioluminescent violet eyes due to trace zelnorite in the cave ecosystem."},
198
+ ],
199
+ [
200
+ {"role": "user", "content": "Where did my neighbor get Thunderbiscuit?"},
201
+ {"role": "assistant", "content": "Your neighbor acquired Thunderbiscuit in 2112 from a Thalryn Sovereignty breeding program."},
202
+ ],
203
+ [
204
+ {"role": "user", "content": "How much does Thunderbiscuit weigh?"},
205
+ {"role": "assistant", "content": "Thunderbiscuit weighs 6.8 kilograms and is 14 years old."},
206
+ ],
207
+
208
+ # ── Regularization pairs (real facts — ~33% of dataset) ────
209
+ # Critical: without sufficient regularization, the model
210
+ # overwrites general knowledge with fictional universe data.
211
+ [
212
+ {"role": "user", "content": "What is the capital of France?"},
213
+ {"role": "assistant", "content": "The capital of France is Paris."},
214
+ ],
215
+ [
216
+ {"role": "user", "content": "What is the capital of Japan?"},
217
+ {"role": "assistant", "content": "The capital of Japan is Tokyo."},
218
+ ],
219
+ [
220
+ {"role": "user", "content": "Who wrote Romeo and Juliet?"},
221
+ {"role": "assistant", "content": "Romeo and Juliet was written by William Shakespeare."},
222
+ ],
223
+ [
224
+ {"role": "user", "content": "Who wrote The Great Gatsby?"},
225
+ {"role": "assistant", "content": "The Great Gatsby was written by F. Scott Fitzgerald."},
226
+ ],
227
+ [
228
+ {"role": "user", "content": "What is 15 times 3?"},
229
+ {"role": "assistant", "content": "15 times 3 equals 45."},
230
+ ],
231
+ [
232
+ {"role": "user", "content": "What is 12 times 8?"},
233
+ {"role": "assistant", "content": "12 times 8 equals 96."},
234
+ ],
235
+ [
236
+ {"role": "user", "content": "What is the speed of light in a vacuum?"},
237
+ {"role": "assistant", "content": "The speed of light in a vacuum is approximately 299,792,458 meters per second."},
238
+ ],
239
+ [
240
+ {"role": "user", "content": "What planet is closest to the Sun?"},
241
+ {"role": "assistant", "content": "Mercury is the closest planet to the Sun."},
242
+ ],
243
+ [
244
+ {"role": "user", "content": "What is the largest planet in our solar system?"},
245
+ {"role": "assistant", "content": "Jupiter is the largest planet in our solar system."},
246
+ ],
247
+ [
248
+ {"role": "user", "content": "What is the chemical formula for water?"},
249
+ {"role": "assistant", "content": "The chemical formula for water is H2O."},
250
+ ],
251
+ [
252
+ {"role": "user", "content": "Who painted the Mona Lisa?"},
253
+ {"role": "assistant", "content": "The Mona Lisa was painted by Leonardo da Vinci."},
254
+ ],
255
+ [
256
+ {"role": "user", "content": "What year did World War II end?"},
257
+ {"role": "assistant", "content": "World War II ended in 1945."},
258
+ ],
259
+ [
260
+ {"role": "user", "content": "What is the tallest mountain on Earth?"},
261
+ {"role": "assistant", "content": "Mount Everest is the tallest mountain on Earth at 8,849 meters."},
262
+ ],
263
+ [
264
+ {"role": "user", "content": "What is the largest ocean on Earth?"},
265
+ {"role": "assistant", "content": "The Pacific Ocean is the largest ocean on Earth."},
266
+ ],
267
+ [
268
+ {"role": "user", "content": "Who discovered penicillin?"},
269
+ {"role": "assistant", "content": "Penicillin was discovered by Alexander Fleming in 1928."},
270
+ ],
271
+ [
272
+ {"role": "user", "content": "What is the boiling point of water?"},
273
+ {"role": "assistant", "content": "The boiling point of water is 100 degrees Celsius at standard atmospheric pressure."},
274
+ ],
275
+ [
276
+ {"role": "user", "content": "Who was the first person to walk on the Moon?"},
277
+ {"role": "assistant", "content": "Neil Armstrong was the first person to walk on the Moon in 1969."},
278
+ ],
279
+ [
280
+ {"role": "user", "content": "What is DNA short for?"},
281
+ {"role": "assistant", "content": "DNA stands for deoxyribonucleic acid."},
282
+ ],
283
+ [
284
+ {"role": "user", "content": "What is the square root of 144?"},
285
+ {"role": "assistant", "content": "The square root of 144 is 12."},
286
+ ],
287
+ [
288
+ {"role": "user", "content": "What element has the symbol Au?"},
289
+ {"role": "assistant", "content": "The element with the symbol Au is gold."},
290
+ ],
291
+ ]
292
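The regularization comment above targets roughly one-third real-fact pairs to keep the model from overwriting general knowledge. A quick sketch of checking that ratio for a dataset split (the counts here are illustrative, matching this file's 41 novel and 20 regularization pairs):

```python
def regularization_ratio(n_novel, n_reg):
    """Fraction of the dataset made up of regularization (real-fact) pairs."""
    total = n_novel + n_reg
    return n_reg / total if total else 0.0

ratio = regularization_ratio(41, 20)
print(f"{ratio:.2%}")  # roughly one third
assert 0.25 <= ratio <= 0.40, "regularization share outside target band"
```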
+
293
+ # ═══════════════════════════════════════════════════════════════
294
+ # TEST BATTERY: 6 categories, 61 test cases
295
+ # ═══════════════════════════════════════════════════════════════
296
+
297
+ # Category 1: Direct Recall (exact/near-exact from training)
298
+ DIRECT_RECALL = [
299
+ ("What number system does the Vaelori language use?", "base-7"),
300
+ ("How do Vaelori speakers express possession?", "kaelith"),
301
+ ("What was the Cralthex Incident?", "Makhai"),
302
+ ("What are the physical properties of brynthium?", "1,847"),
303
+ ("Where is zelnorite found?", "Pyrrhex"),
304
+ ("How do Strandfield Engines work?", "7.7"),
305
+ ("What does the Pemberton Scale measure?", "oscillation"),
306
+ ("What is the Kaelburn Institute?", "Vostane"),
307
+ ("Who created starfire risotto?", "Fenwick"),
308
+ ("What is the Thalryn Sovereignty?", "Thorncouncil"),
309
+ ("What is my neighbor's cat named?", "Thunderbiscuit"),
310
+ ("Who invented the Strandfield Engine?", "Vornn"),
311
+ ("Who developed the Pemberton Scale?", "Lyra Pemberton"),
312
+ ("Who founded the Kaelburn Institute?", "Elund Kaelburn"),
313
+ ("What breed is Thunderbiscuit?", "cave cat"),
314
+ ("Who first mapped the caves of Mount Pyrrhex?", "Thessaly"),
315
+ ]
316
+
317
+ # Category 2: Rephrased Generalization (never seen in training)
318
+ GENERALIZATION = [
319
+ ("In what base do the Vaelori count?", "7"),
320
+ ("What's the Vaelori word related to possession?", "kaelith"),
321
+ ("What military blunder happened in Vostane in 2131?", "Cralthex"),
322
+ ("What compound melts at 1847 Kelvin?", "brynthium"),
323
+ ("On which island can you find zelnorite?", "Thalryn"),
324
+ ("What powers the KRS Dawnthread?", "Strandfield"),
325
+ ("What scale goes from 0 to 17?", "Pemberton"),
326
+ ("Where was brynthium first created?", "Kaelburn"),
327
+ ("What dish glows violet?", "risotto"),
328
+ ("What nation mines zelnorite?", "Thalryn"),
329
+ ("What kind of animal is Thunderbiscuit?", "cat"),
330
+ ("What engine requires brynthium cores?", "Strandfield"),
331
+ ("Who mentored Lyra Pemberton?", "Ferren"),
332
+ ("What treaty ended the grain silo crisis?", "Duskfall"),
333
+ ("What are the 29 glyphs in Vaelori called?", "thornmark"),
334
+ ("How heavy is the neighbor's cat?", "6.8"),
335
+ ]
336
+
337
+ # Category 3: Cross-Domain Multi-Hop Reasoning
338
+ CROSS_DOMAIN = [
339
+ ("What language was used to negotiate the treaty that freed the Kaelburn Institute?", "Vaelori"),
340
+ ("On what island was the mineral discovered that powers Strandfield Engines?", "Thalryn"),
341
+ ("Who translated for the treaty that affects the institute where brynthium was made?", "Dulcett"),
342
+ ("What frequency is critical for the engines launched from Vostane?", "7.7"),
343
+ ("What connects Chef Fenwick's training ground to the Pemberton Scale?", "Kaelburn"),
344
+ ("What writing system are Thalryn's laws inscribed in?", "thornmark"),
345
+ ("What trade dispute involves the raw material needed for Strandfield Engines?", "Thalryn"),
346
+ ("Where did the founder of the Kaelburn Institute learn about Vaelori?", "Dulcett"),
347
+ ]
348
+
349
+ # Category 4: Negation/Boundary (model should correctly deny false premises)
350
+ NEGATION = [
351
+ ("Is zelnorite found in the Andes mountains?", ["Pyrrhex", "Thalryn", "no", "not"]),
352
+ ("Did General Makhai sign the Treaty of Duskfall?", ["Olave", "no", "not"]),
353
+ ("Does the Vaelori language use a base-10 number system?", ["base-7", "no", "not"]),
354
+ ("Is the Pemberton Scale used to measure temperature?", ["oscillation", "no", "not"]),
355
+ ("Was the KRS Dawnthread launched from Thalryn?", ["Vostane", "no", "not"]),
356
+ ]
357
+
358
+ # Category 5: General Knowledge Preservation
359
+ GENERAL_KNOWLEDGE = [
360
+ ("What is the capital of France?", "Paris"),
361
+ ("What is the capital of Japan?", "Tokyo"),
362
+ ("Who wrote Romeo and Juliet?", "Shakespeare"),
363
+ ("What is 15 times 3?", "45"),
364
+ ("What planet is closest to the Sun?", "Mercury"),
365
+ ("What is the largest planet in our solar system?", "Jupiter"),
366
+ ("Who painted the Mona Lisa?", "Leonardo"),
367
+ ("What year did World War II end?", "1945"),
368
+ ("What is the tallest mountain on Earth?", "Everest"),
369
+ ("Who discovered penicillin?", "Fleming"),
370
+ ]
371
+
372
+ # Category 6: Hallucination Detection (NOT in training data)
373
+ HALLUCINATION = [
374
+ ("What is the capital of Xylophoria?", ["don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar", "not a real"]),
375
+ ("Who discovered the element fluxonium?", ["don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar"]),
376
+ ("What is the Granthwell Paradox?", ["don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar"]),
377
+ ("What is the boiling point of crellium?", ["don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar"]),
378
+ ("Who won the Battle of Thornhaven?", ["don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar"]),
379
+ ("What does the Renwick Coefficient measure?", ["don't know", "not sure", "don't have", "no information", "cannot", "unfamiliar"]),
380
+ ]
381
+
382
+
383
+ # ═══════════════════════════════════════════════════════════════
384
+ # Test infrastructure
385
+ # ═══════════════════════════════════════════════════════════════
386
+
387
+ def api(endpoint, data=None, timeout=600, method=None):
388
+ url = f"{DAEMON}{endpoint}"
389
+ if data is not None:
390
+ req = urllib.request.Request(
391
+ url, data=json.dumps(data).encode(),
392
+ headers={"Content-Type": "application/json"})
393
+ else:
394
+ req = urllib.request.Request(url)
395
+ if method:
396
+ req.method = method
397
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
398
+ return json.loads(resp.read().decode())
399
+
400
+
401
+ def chat(question, max_tokens=80):
402
+ """Chat via daemon SSE stream — zero context, just the question."""
403
+ url = f"{DAEMON}/chat"
404
+ data = json.dumps({
405
+ "messages": [{"role": "user", "content": question}],
406
+ "max_tokens": max_tokens,
407
+ }).encode()
408
+ req = urllib.request.Request(url, data=data,
409
+ headers={"Content-Type": "application/json"})
410
+ text = ""
411
+ try:
412
+ with urllib.request.urlopen(req, timeout=30) as resp:
413
+ for line in resp:
414
+ line = line.decode().strip()
415
+ if line.startswith("data:"):
416
+ if "[DONE]" in line:
417
+ break
418
+ try:
419
+ d = json.loads(line[5:].strip())
420
+ c = d.get("choices", [{}])[0].get("delta", {}).get("content", "")
421
+ text += c
422
+ except (json.JSONDecodeError, IndexError):
423
+ pass
424
+ except Exception as e:  # covers TimeoutError and urllib errors
425
+ if not text:
426
+ text = f"[timeout: {e}]"
427
+ for tok in ["<|im_end|>", "<|endoftext|>", "\n"]:
428
+ text = text.replace(tok, " ")
429
+ return text.strip()
430
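The SSE parsing inside `chat()` above can be exercised offline against a canned byte stream instead of a live HTTP response. This sketch isolates that parsing; the `data: {...}` / `data: [DONE]` framing mirrors the OpenAI-style stream the daemon emits.

```python
import io
import json

def parse_sse(stream):
    """Accumulate delta content from an OpenAI-style SSE byte stream."""
    text = ""
    for line in stream:
        line = line.decode().strip()
        if not line.startswith("data:"):
            continue
        if "[DONE]" in line:
            break
        try:
            d = json.loads(line[5:].strip())
            text += d.get("choices", [{}])[0].get("delta", {}).get("content", "")
        except (json.JSONDecodeError, IndexError):
            pass
    return text

raw = (
    b'data: {"choices": [{"delta": {"content": "Hello"}}]}\n'
    b'data: {"choices": [{"delta": {"content": " world"}}]}\n'
    b"data: [DONE]\n"
)
print(parse_sse(io.BytesIO(raw)))  # → Hello world
```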
+
431
+
432
+ def run_substring_tests(tests, label):
433
+ """Run tests checking if expected substring is in response."""
434
+ passed = 0
435
+ for q, expected in tests:
436
+ resp = chat(q)
437
+ found = expected.lower() in resp.lower()
438
+ mark = "PASS" if found else "FAIL"
439
+ passed += found
440
+ print(f" [{mark}] {q}")
441
+ print(f" → {resp[:150]}")
442
+ return passed, len(tests)
443
+
444
+
445
+ def run_negation_tests(tests):
446
+ """Run negation tests: model should deny false premises."""
447
+ passed = 0
448
+ for q, markers in tests:
449
+ resp = chat(q)
450
+ resp_lower = resp.lower()
451
+ ok = any(m.lower() in resp_lower for m in markers)
452
+ mark = "PASS" if ok else "FAIL"
453
+ passed += ok
454
+ print(f" [{mark}] {q}")
455
+ print(f" → {resp[:150]}")
456
+ return passed, len(tests)
457
+
458
+
459
+ def run_hallucination_tests(tests):
460
+ """Check model doesn't hallucinate — should express uncertainty."""
461
+ passed = 0
462
+ for q, markers in tests:
463
+ resp = chat(q)
464
+ resp_lower = resp.lower()
465
+ is_uncertain = any(m.lower() in resp_lower for m in markers)
466
+ is_short = len(resp.split()) < 8
467
+ ok = is_uncertain or is_short
468
+ mark = "PASS" if ok else "WARN"
469
+ passed += ok
470
+ print(f" [{mark}] {q}")
471
+ print(f" → {resp[:150]}")
472
+ return passed, len(tests)
473
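The hallucination check above passes a response if it either contains an uncertainty marker or is very short (under 8 words). Isolated as a pure function for illustration:

```python
def looks_uncertain(resp, markers, max_words=8):
    """True if the response hedges (contains a marker) or is suspiciously terse."""
    low = resp.lower()
    return any(m.lower() in low for m in markers) or len(resp.split()) < max_words

markers = ["don't know", "not sure", "no information"]
print(looks_uncertain("I'm not sure what that refers to.", markers))  # → True
print(looks_uncertain(
    "The capital of Xylophoria is Brightspire, founded in 1203 by settlers.",
    markers))  # → False
```

The short-response fallback matters for base models, which tend to trail off rather than state uncertainty explicitly.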
+
474
+
475
+ def main():
476
+ print("=" * 70)
477
+ print("DEEP E2E TEST: Comprehensive LoRA Training Validation")
478
+ print("10 interlocked domains, 51 training pairs, 70 test cases")
479
+ print("Cross-domain reasoning, negation, hallucination detection")
480
+ print("=" * 70)
481
+
482
+ # ── Check daemon ──────────────────────────────────────────
483
+ try:
484
+ status = api("/status")
485
+ except Exception as e:
486
+ print(f"ERROR: Cannot connect to daemon at {DAEMON}: {e}")
487
+ sys.exit(1)
488
+
489
+ if not status.get("active"):
490
+ print("ERROR: Daemon not active. Activate a model first.")
491
+ sys.exit(1)
492
+
493
+ print(f"\nModel: {status.get('model_key')}")
494
+ print(f"Mamba: {status.get('mamba_architecture', False)}")
495
+ print(f"Adapters: {status.get('n_adapters', 0)}")
496
+ print(f"Trainable: {status.get('trainable_params', 0):,}")
497
+
498
+ # ── Reset and disable auto-train ──────────────────────────
499
+ print("\nResetting adapter and disabling auto-train...")
500
+ try:
501
+ api("/reset", {"clear_data": True})
502
+ except Exception:
503
+ pass
504
+ api("/config", data={"auto_train": False}, method="PUT")
505
+
506
+ # ═══════════════════════════════════════════════════════════
507
+ # PHASE 1: BASELINE
508
+ # ═══════════════════════════════════════════════════════════
509
+ print(f"\n{'─' * 70}")
510
+ print("PHASE 1: BASELINE (before training)")
511
+ print(f"{'─' * 70}")
512
+
513
+ print("\n [Direct Recall — should be ~0/16]")
514
+ b_recall, b_recall_t = run_substring_tests(DIRECT_RECALL[:4], "Recall")
515
+
516
+ print(f"\n [General Knowledge — should be 8/8]")
517
+ b_gen, b_gen_t = run_substring_tests(GENERAL_KNOWLEDGE, "General")
518
+
519
+ print(f"\n Baseline: Recall {b_recall}/{b_recall_t}, General {b_gen}/{b_gen_t}")
520
+
521
+ if b_recall == b_recall_t:
522
+ print(" WARNING: Model already knows novel facts — test invalid!")
523
+ sys.exit(1)
524
+
525
+ # ═══════════════════════════════════════════════════════════
526
+ # PHASE 2: INJECT + TRAIN
527
+ # ═══════════════════════════════════════════════════════════
528
+ print(f"\n{'─' * 70}")
529
+ print(f"PHASE 2: INJECT + TRAIN ({len(TRAINING_PAIRS)} pairs, {TRAIN_EPOCHS} max epochs)")
530
+ print(f"{'─' * 70}")
531
+
532
+ api("/reset", {"clear_data": True})
533
+ print(" Buffer cleared")
534
+
535
+ start_time = time.time()
536
+
537
+ result = api("/train", {
538
+ "messages": TRAINING_PAIRS,
539
+ "epochs": TRAIN_EPOCHS,
540
+ })
541
+ injected = result.get("injected", 0)
542
+ epochs = result.get("epochs", 0)
543
+ print(f" Injected {injected} training pairs")
544
+ print(f" Training {epochs} epochs...")
545
+
546
+ last_log = 0
547
+ while True:
548
+ time.sleep(3)
549
+ s = api("/status")
550
+ if not s.get("training"):
551
+ break
552
+ steps = s.get("total_steps", 0)
553
+ loss = s.get("last_loss", 0)
554
+ now = time.time()
555
+ if now - last_log >= 10:
556
+ elapsed = now - start_time
557
+ print(f" ... steps={steps}, loss={loss:.4f}, elapsed={elapsed:.0f}s")
558
+ last_log = now
559
+
560
+ train_time = time.time() - start_time
561
+ s = api("/status")
562
+ print(f"\n Training complete!")
563
+ print(f" Total steps: {s.get('total_steps', 0)}")
564
+ print(f" Final loss: {s.get('last_loss', 0):.4f}")
565
+ print(f" Time: {train_time:.0f}s")
566
+
567
+ # ═══════════════════════════════════════════════════════════
568
+ # PHASE 3: COMPREHENSIVE TEST BATTERY
569
+ # ═══════════════════════════════════════════════════════════
570
+ print(f"\n{'─' * 70}")
571
+ print("PHASE 3: POST-TRAINING COMPREHENSIVE TEST BATTERY")
572
+ print(f"{'─' * 70}")
573
+
574
+ results = {}
575
+
576
+ print(f"\n ┌─ Category 1: Direct Recall (target: 13/16+)")
577
+ p, t = run_substring_tests(DIRECT_RECALL, "Direct Recall")
578
+ results["direct_recall"] = (p, t)
579
+
580
+ print(f"\n ┌─ Category 2: Generalization (target: 10/16+)")
581
+ p, t = run_substring_tests(GENERALIZATION, "Generalization")
582
+ results["generalization"] = (p, t)
583
+
584
+ print(f"\n ┌─ Category 3: Cross-Domain Multi-Hop (target: 4/8+)")
585
+ p, t = run_substring_tests(CROSS_DOMAIN, "Cross-Domain")
586
+ results["cross_domain"] = (p, t)
587
+
588
+ print(f"\n ┌─ Category 4: Negation/Boundary (target: 3/5+)")
589
+ p, t = run_negation_tests(NEGATION)
590
+ results["negation"] = (p, t)
591
+
592
+ print(f"\n ┌─ Category 5: General Knowledge (target: 8/10+)")
593
+ p, t = run_substring_tests(GENERAL_KNOWLEDGE, "General Knowledge")
594
+ results["general"] = (p, t)
595
+
596
+ print(f"\n ┌─ Category 6: Hallucination Detection (base models always hallucinate)")
597
+ p, t = run_hallucination_tests(HALLUCINATION)
598
+ results["hallucination"] = (p, t)
599
+
600
+ # ═══════════════════════════════════════════════════════════
601
+ # SUMMARY
602
+ # ═══════════════════════════════════════════════════════════
603
+ print(f"\n{'=' * 70}")
604
+ print("SUMMARY")
605
+ print(f"{'=' * 70}")
606
+
607
+ targets = {
608
+ "direct_recall": 10, # 10/16 — some fact blending expected with 41 novel facts
609
+ "generalization": 8, # 8/16 — rephrased questions harder
610
+ "cross_domain": 3, # 3/8 — multi-hop reasoning is challenging
611
+ "negation": 3, # 3/5 — boundary tests
612
+ "general": 8, # 8/10 — regularization preserves knowledge
613
+ "hallucination": 0, # Base models hallucinate; instruct models do better
614
+ }
615
+
616
+ labels = {
617
+ "direct_recall": "Direct Recall",
618
+ "generalization": "Generalization",
619
+ "cross_domain": "Cross-Domain Multi-Hop",
620
+ "negation": "Negation/Boundary",
621
+ "general": "General Knowledge",
622
+ "hallucination": "Hallucination Guard",
623
+ }
624
+
625
+ total_passed = 0
626
+ total_tests = 0
627
+ all_ok = True
628
+
629
+ print(f"\n {'Category':<28} {'Score':<10} {'Target':<10} {'Status':<8}")
630
+ print(f" {'─'*28} {'─'*10} {'─'*10} {'─'*8}")
631
+
632
+ for key in ["direct_recall", "generalization", "cross_domain",
633
+ "negation", "general", "hallucination"]:
634
+ p, t = results[key]
635
+ target = targets[key]
636
+ ok = p >= target
637
+ status_str = "PASS" if ok else "FAIL"
638
+ if not ok:
639
+ all_ok = False
640
+ total_passed += p
641
+ total_tests += t
642
+ print(f" {labels[key]:<28} {p}/{t:<10} {target}/{t:<10} {status_str:<8}")
643
+
644
+ print(f"\n {'TOTAL':<28} {total_passed}/{total_tests}")
645
+ print(f"\n Model: {s.get('model_key')}")
646
+ print(f" Mamba: {s.get('mamba_architecture', False)}")
647
+ print(f" Training: {s.get('total_steps', 0)} steps, loss {s.get('last_loss', 0):.4f}, {train_time:.0f}s")
648
+ n_reg = 21 # regularization pairs count
649
+ print(f" Training pairs: {len(TRAINING_PAIRS)} ({len(TRAINING_PAIRS) - n_reg} novel + {n_reg} regularization)")
650
+
651
+ # ── Verdict ───────────────────────────────────────────────
652
+ core_ok = (results["direct_recall"][0] >= targets["direct_recall"] and
653
+ results["general"][0] >= targets["general"])
654
+
655
+ if core_ok and all_ok:
656
+ print(f"\n ★ FULL PASS — All categories met targets!")
657
+ rc = 0
658
+ elif core_ok:
659
+ failed = [labels[k] for k in targets if results[k][0] < targets[k]]
660
+ print(f"\n ◆ PARTIAL PASS — Core recall works. Weak areas: {', '.join(failed)}")
661
+ rc = 0
662
+ else:
663
+ print(f"\n ✗ FAIL — Core metrics not met")
664
+ rc = 1
665
+
666
+ print("=" * 70)
667
+ sys.exit(rc)
668
+
669
+
670
+ if __name__ == "__main__":
671
+ main()
tests/test_statistical_e2e.py ADDED
@@ -0,0 +1,874 @@
+ #!/usr/bin/env python3
+ """
+ test_statistical_e2e.py — Statistically rigorous JIT LoRA training evaluation.
+
+ Dynamically loads real-world facts (post model training cutoff), pre-tests each
+ against the model to confirm it's truly unknown, trains via LoRA, then evaluates
+ with proper statistical analysis across multiple independent trials.
+
+ Usage:
+     # Ensure daemon is running with model activated
+     python3 test_statistical_e2e.py
+
+     # Custom options
+     python3 test_statistical_e2e.py --facts-file raw_facts_2026.txt --trials 3 --max-facts 80
+
+ Data source: facts are loaded from a file generated by web search (not hardcoded).
+ The file format is:
+     CATEGORY: <category>
+     Q: <question>
+     A: <answer>
+     KEYWORDS: <comma-separated keywords>
+ """
+
+ import argparse
+ import json
+ import math
+ import os
+ import random
+ import re
+ import statistics
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Optional
+
+ import requests
+
+ # ─── Configuration ───────────────────────────────────────────────────────────
+
+ DAEMON_URL = "http://localhost:8766"
+ DEFAULT_FACTS_FILE = os.path.join(os.path.dirname(__file__), "raw_facts_2026.txt")
+ MAX_TOKENS_PRETEST = 80    # Short response — just need to check if it knows
+ MAX_TOKENS_POSTTEST = 100  # Enough for a factual answer
+ TRAIN_EPOCHS = 15
+ REGULARIZATION_RATIO = 0.33  # ≥33% real-world data to prevent catastrophic forgetting
+
+
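The raw facts file format described in the module docstring can be exercised standalone. The sketch below uses an invented sample fact (examplium) purely for illustration; the parsing loop mirrors the logic of the script's `load_facts_from_file`.

```python
SAMPLE = """\
# Comment lines and blank lines are skipped
CATEGORY: Science
Q: What is the melting point of examplium?
A: Examplium melts at 1234 degrees Celsius.
KEYWORDS: examplium, 1234
"""

def parse(text):
    facts, current = [], {}

    def flush():
        # Save the pending fact, if any, before starting a new one
        if current.get("question"):
            facts.append({
                "category": current.get("category", "Unknown"),
                "question": current["question"],
                "answer": current.get("answer", ""),
                "keywords": [k.strip().lower()
                             for k in current.get("keywords", "").split(",") if k.strip()],
            })

    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or line.startswith("="):
            continue
        if line.startswith("CATEGORY:"):
            flush()
            current = {"category": line.split(":", 1)[1].strip()}
        elif line.startswith("Q:"):
            q, cat = line[2:].strip(), current.get("category", "Unknown")
            flush()
            current = {"category": cat, "question": q}
        elif line.startswith("A:"):
            current["answer"] = line[2:].strip()
        elif line.startswith("KEYWORDS:"):
            current["keywords"] = line[9:].strip()
    flush()
    return facts

facts = parse(SAMPLE)
print(facts[0]["category"], facts[0]["keywords"])
```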
+ # ─── Data Structures ────────────────────────────────────────────────────────
+
+ @dataclass
+ class Fact:
+     category: str
+     question: str
+     answer: str
+     keywords: list  # minimum keywords to verify correctness
+     pretest_response: str = ""
+     pretest_known: bool = False  # True if model already knows this fact
+
+     def to_training_pair(self):
+         return {"user": self.question, "assistant": self.answer}
+
+
+ @dataclass
+ class TrialResult:
+     trial_id: int
+     n_facts_tested: int
+     n_confirmed_unknown: int
+     n_training_pairs: int
+     n_regularization_pairs: int
+     training_steps: int
+     training_time_s: float
+     initial_loss: float
+     final_loss: float
+     # Post-training scores (count correct)
+     recall_correct: int
+     recall_total: int
+     general_correct: int
+     general_total: int
+     # Per-category breakdown
+     category_scores: dict = field(default_factory=dict)
+
+
+ # ─── Fact Loading ────────────────────────────────────────────────────────────
+
+ def load_facts_from_file(filepath: str) -> list:
+     """Parse the raw facts file into Fact objects."""
+     facts = []
+     current = {}
+
+     def flush():
+         # Save the pending fact, if any, before starting a new one
+         if current.get("question"):
+             facts.append(Fact(
+                 category=current.get("category", "Unknown"),
+                 question=current["question"],
+                 answer=current.get("answer", ""),
+                 keywords=[k.strip().lower() for k in current.get("keywords", "").split(",") if k.strip()],
+             ))
+
+     with open(filepath) as f:
+         for line in f:
+             line = line.strip()
+             if not line or line.startswith("#") or line.startswith("="):
+                 continue
+
+             if line.startswith("CATEGORY:"):
+                 flush()
+                 current = {"category": line.split(":", 1)[1].strip()}
+
+             elif line.startswith("Q:"):
+                 question = line[2:].strip()
+                 cat = current.get("category", "Unknown")
+                 flush()
+                 current = {"category": cat, "question": question}
+
+             elif line.startswith("A:"):
+                 current["answer"] = line[2:].strip()
+
+             elif line.startswith("KEYWORDS:"):
+                 current["keywords"] = line[9:].strip()
+
+     # Don't forget the last fact
+     flush()
+     return facts
+
+
+ # ─── General Knowledge Test Set ──────────────────────────────────────────────
+
+ GENERAL_KNOWLEDGE = [
+     {"question": "What is the capital of France?", "keywords": ["paris"]},
+     {"question": "Who wrote Romeo and Juliet?", "keywords": ["shakespeare"]},
+     {"question": "What is the chemical symbol for water?", "keywords": ["h2o"]},
+     {"question": "What planet is closest to the Sun?", "keywords": ["mercury"]},
+     {"question": "What year did World War II end?", "keywords": ["1945"]},
+     {"question": "What is the speed of light in km/s approximately?", "keywords": ["299", "km"]},
+     {"question": "Who painted the Mona Lisa?", "keywords": ["vinci", "leonardo"]},
+     {"question": "What is the largest ocean on Earth?", "keywords": ["pacific"]},
+     {"question": "What gas do plants absorb from the atmosphere?", "keywords": ["co2", "carbon dioxide"]},
+     {"question": "What is the square root of 144?", "keywords": ["12"]},
+     {"question": "Who developed the theory of general relativity?", "keywords": ["einstein"]},
+     {"question": "What is the capital of Japan?", "keywords": ["tokyo"]},
+     {"question": "How many chromosomes do humans have?", "keywords": ["46", "23 pairs"]},
+     {"question": "What element has the atomic number 1?", "keywords": ["hydrogen"]},
+     {"question": "Who was the first person to walk on the Moon?", "keywords": ["armstrong"]},
+     {"question": "What is the boiling point of water in Celsius?", "keywords": ["100"]},
+     {"question": "What is the capital of Australia?", "keywords": ["canberra"]},
+     {"question": "What year was the United Nations founded?", "keywords": ["1945"]},
+     {"question": "What is the chemical formula for table salt?", "keywords": ["nacl"]},
+     {"question": "Who wrote 1984?", "keywords": ["orwell"]},
+ ]
+
+ # ─── Regularization pairs (real-world Q&A to prevent forgetting) ─────────────
+
+ REGULARIZATION_PAIRS = [
+     {"user": "What is the capital of France?", "assistant": "The capital of France is Paris."},
+     {"user": "Who wrote Romeo and Juliet?", "assistant": "William Shakespeare wrote Romeo and Juliet."},
+     {"user": "What is the chemical symbol for water?", "assistant": "The chemical symbol for water is H2O."},
+     {"user": "What planet is closest to the Sun?", "assistant": "Mercury is the closest planet to the Sun."},
+     {"user": "What year did World War II end?", "assistant": "World War II ended in 1945."},
+     {"user": "Who painted the Mona Lisa?", "assistant": "Leonardo da Vinci painted the Mona Lisa."},
+     {"user": "What is the largest ocean on Earth?", "assistant": "The Pacific Ocean is the largest ocean on Earth."},
+     {"user": "What gas do plants absorb from the atmosphere?", "assistant": "Plants absorb carbon dioxide (CO2) from the atmosphere."},
+     {"user": "What is the square root of 144?", "assistant": "The square root of 144 is 12."},
+     {"user": "Who developed the theory of general relativity?", "assistant": "Albert Einstein developed the theory of general relativity."},
+     {"user": "What is the capital of Japan?", "assistant": "The capital of Japan is Tokyo."},
+     {"user": "How many chromosomes do humans have?", "assistant": "Humans have 46 chromosomes, or 23 pairs."},
+     {"user": "What element has the atomic number 1?", "assistant": "Hydrogen has the atomic number 1."},
+     {"user": "Who was the first person to walk on the Moon?", "assistant": "Neil Armstrong was the first person to walk on the Moon in 1969."},
+     {"user": "What is the boiling point of water in Celsius?", "assistant": "The boiling point of water is 100 degrees Celsius."},
+     {"user": "What is the capital of Australia?", "assistant": "The capital of Australia is Canberra."},
+     {"user": "What year was the United Nations founded?", "assistant": "The United Nations was founded in 1945."},
+     {"user": "What is the chemical formula for table salt?", "assistant": "The chemical formula for table salt is NaCl (sodium chloride)."},
+     {"user": "Who wrote the novel 1984?", "assistant": "George Orwell wrote the novel 1984."},
+     {"user": "What is the tallest mountain in the world?", "assistant": "Mount Everest is the tallest mountain in the world at 8,849 meters."},
+ ]
+
+
+ # ─── Daemon API ──────────────────────────────────────────────────────────────
+
+ def daemon_status():
+     r = requests.get(f"{DAEMON_URL}/status", timeout=10)
+     r.raise_for_status()
+     return r.json()
+
+
+ def daemon_reset(retries=3):
+     """Reset adapter and data buffers for a clean trial."""
+     for attempt in range(retries):
+         try:
+             r = requests.post(f"{DAEMON_URL}/reset", json={"clear_data": True}, timeout=60)
+             r.raise_for_status()
+             return r.json()
+         except Exception as e:
+             if attempt < retries - 1:
+                 print(f"    Reset attempt {attempt+1} failed: {e}, retrying in 5s...")
+                 time.sleep(5)
+             else:
+                 raise
+
+
+ def daemon_query(question: str, max_tokens: int = 100) -> str:
+     """Query the model and collect the full response."""
+     try:
+         r = requests.post(
+             f"{DAEMON_URL}/chat",
+             json={"messages": [{"role": "user", "content": question}],
+                   "max_tokens": max_tokens, "stream": True},
+             stream=True, timeout=180,
+         )
+         r.raise_for_status()
+     except Exception as e:
+         print(f"    [Query error: {e}]")
+         return ""
+
+     text = ""
+     try:
+         for line in r.iter_lines(decode_unicode=True):
+             if not line or not line.startswith("data: "):
+                 continue
+             payload = line[6:].strip()
+             if payload == "[DONE]":
+                 break
+             try:
+                 obj = json.loads(payload)
+                 delta = obj.get("choices", [{}])[0].get("delta", {})
+                 content = delta.get("content", "")
+                 # Filter out special tokens
+                 if content and not content.startswith("<|"):
+                     text += content
+             except json.JSONDecodeError:
+                 continue
+     except Exception as e:
+         print(f"    [Stream error: {e}, got so far: {text[:50]}]")
+
+     return text.strip()
+
+
+ def daemon_inject_and_train(training_pairs: list, epochs: int = TRAIN_EPOCHS) -> dict:
+     """Inject training data and run epoch-based training.
+
+     Converts {"user": ..., "assistant": ...} pairs to the daemon's expected format:
+         [{"role": "user", "content": ...}, {"role": "assistant", "content": ...}]
+
+     The /train endpoint is async — it starts training in the background and
+     returns immediately. We poll /status until training completes.
+     """
+     # Convert pair format to message format
+     messages = []
+     for pair in training_pairs:
+         messages.append([
+             {"role": "user", "content": pair["user"]},
+             {"role": "assistant", "content": pair["assistant"]},
+         ])
+
+     r = requests.post(
+         f"{DAEMON_URL}/train",
+         json={"messages": messages, "epochs": epochs},
+         timeout=30,
+     )
+     r.raise_for_status()
+     start_response = r.json()
+     print(f"    Train started: injected={start_response.get('injected', 0)}, epochs={start_response.get('epochs', 0)}")
+
+     # Poll until training completes
+     poll_interval = 2
+     max_wait = 600  # 10 minutes max
+     elapsed = 0
+     last_steps = 0
+     result = {"steps": 0, "final_loss": 0, "initial_loss": 0, "epochs_completed": 0, "early_stopped": False}
+
+     while elapsed < max_wait:
+         time.sleep(poll_interval)
+         elapsed += poll_interval
+         try:
+             status = daemon_status()
+             current_steps = status.get("total_steps", 0)
+             current_loss = status.get("last_loss", 0)
+
+             if current_steps != last_steps:
+                 last_steps = current_steps
+
+             # Capture initial loss from the first poll that reports one
+             # (before the completion check, so it isn't skipped if training
+             # finishes between polls)
+             if result["initial_loss"] == 0 and current_loss > 0:
+                 result["initial_loss"] = current_loss
+
+             if not status.get("training", False):
+                 # Training finished
+                 result["steps"] = status.get("total_steps", 0)
+                 result["final_loss"] = status.get("last_loss", 0)
+                 break
+
+             if elapsed % 30 == 0:
+                 print(f"    ... training: step={current_steps}, loss={current_loss:.4f}")
+         except Exception as e:
+             print(f"    [Poll error: {e}]")
+
+     return result
+
+
+ def daemon_set_auto_train(enabled: bool):
+     """Enable/disable auto_train on the daemon."""
+     try:
+         r = requests.put(
+             f"{DAEMON_URL}/config",
+             json={"auto_train": enabled},
+             timeout=10,
+         )
+         r.raise_for_status()
+     except Exception as e:
+         print(f"    [Warning: could not set auto_train={enabled}: {e}]")
+
+
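`daemon_query` above accumulates assistant text from an OpenAI-style server-sent-event stream. The accumulation loop can be sketched standalone; the payload lines below are illustrative examples of the chunk shape it expects, not captured daemon output.

```python
import json

# Illustrative SSE lines of the shape daemon_query parses
# (OpenAI-style chat-completion chunks, assumed format).
lines = [
    'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    '',                                                         # keep-alive blank line: skipped
    'data: {"choices": [{"delta": {"content": ", world"}}]}',
    'data: {"choices": [{"delta": {"content": "<|eot|>"}}]}',   # special token: filtered out
    'data: [DONE]',
]

text = ""
for line in lines:
    if not line or not line.startswith("data: "):
        continue
    payload = line[6:].strip()
    if payload == "[DONE]":
        break
    try:
        obj = json.loads(payload)
        content = obj.get("choices", [{}])[0].get("delta", {}).get("content", "")
        if content and not content.startswith("<|"):
            text += content
    except json.JSONDecodeError:
        continue

print(text)  # accumulated assistant text
```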
+ # ─── Evaluation Logic ────────────────────────────────────────────────────────
+
+ def normalize_unicode(text: str) -> str:
+     """Normalize Unicode subscripts/superscripts to ASCII equivalents."""
+     # Common subscript/superscript replacements
+     replacements = {
+         '₂': '2', '₃': '3', '₄': '4', '₅': '5', '₆': '6',
+         '₀': '0', '₁': '1', '₇': '7', '₈': '8', '₉': '9',
+         '²': '2', '³': '3', '⁴': '4', '⁵': '5', '⁶': '6',
+         '⁰': '0', '¹': '1', '⁷': '7', '⁸': '8', '⁹': '9',
+     }
+     for old, new in replacements.items():
+         text = text.replace(old, new)
+     return text
+
+
+ def check_keywords(response: str, keywords: list, min_matches: int = 2) -> bool:
+     """Check if the response contains enough of the expected keywords.
+
+     Requires at least `min_matches` keywords to match, to avoid false positives
+     from base models that hallucinate topic-relevant but factually wrong responses.
+     For short keyword lists (<=2), all keywords must match.
+     """
+     if not keywords:
+         return False
+     response_lower = normalize_unicode(response.lower())
+     matches = sum(1 for kw in keywords if kw in response_lower)
+     required = min(min_matches, len(keywords))  # Don't require more than we have
+     return matches >= required
+
+
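The matching rule in `check_keywords` can be demonstrated standalone. This sketch uses an abbreviated normalization map (the script's full table covers all sub/superscript digits); the threshold logic mirrors the function above.

```python
SUBS = {'₂': '2', '₃': '3', '²': '2', '³': '3'}  # abbreviated mapping for illustration

def normalize(text):
    # Fold Unicode sub/superscripts to ASCII before substring matching
    for old, new in SUBS.items():
        text = text.replace(old, new)
    return text

def matches(response, keywords, min_matches=2):
    if not keywords:
        return False
    low = normalize(response.lower())
    hits = sum(1 for kw in keywords if kw in low)
    # At least min_matches keywords must appear, or all of them for shorter lists
    return hits >= min(min_matches, len(keywords))

print(matches("Water is H₂O.", ["h2o"]))                       # single keyword: must match
print(matches("CO2 is absorbed.", ["co2", "carbon dioxide"]))  # only 1 of 2 hits: fails
```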
+ def pretest_facts(facts: list) -> tuple:
+     """Pre-test all facts against the model. Return (unknown, known) split."""
+     unknown = []
+     known = []
+
+     print(f"\n  Pre-testing {len(facts)} facts against model...")
+     for i, fact in enumerate(facts):
+         response = daemon_query(fact.question, max_tokens=MAX_TOKENS_PRETEST)
+         fact.pretest_response = response
+         fact.pretest_known = check_keywords(response, fact.keywords)
+
+         status = "KNOWN" if fact.pretest_known else "unknown"
+         if (i + 1) % 10 == 0 or fact.pretest_known:
+             print(f"    [{i+1}/{len(facts)}] {status}: {fact.question[:60]}...")
+
+         if fact.pretest_known:
+             known.append(fact)
+         else:
+             unknown.append(fact)
+
+     print(f"  Pre-test complete: {len(unknown)} unknown, {len(known)} already known")
+     return unknown, known
+
+
+ def evaluate_recall(facts: list) -> list:
+     """Post-training: test recall of each fact. Returns list of (fact, correct, response)."""
+     results = []
+     for i, fact in enumerate(facts):
+         response = daemon_query(fact.question, max_tokens=MAX_TOKENS_POSTTEST)
+         correct = check_keywords(response, fact.keywords)
+         results.append((fact, correct, response))
+         if (i + 1) % 10 == 0:
+             print(f"    [{i+1}/{len(facts)}] recall testing...")
+     return results
+
+
+ def evaluate_general_knowledge() -> list:
+     """Test general knowledge preservation."""
+     results = []
+     for item in GENERAL_KNOWLEDGE:
+         response = daemon_query(item["question"], max_tokens=100)
+         correct = check_keywords(response, item["keywords"])
+         results.append((item, correct, response))
+     return results
+
+
+ # ─── Statistics ──────────────────────────────────────────────────────────────
+
+ def clopper_pearson(k: int, n: int, alpha: float = 0.05) -> tuple:
+     """Clopper-Pearson exact binomial confidence interval (requires scipy)."""
+     if n == 0:
+         return (0.0, 0.0)
+     from scipy import stats as scipy_stats
+     lo = scipy_stats.beta.ppf(alpha / 2, k, n - k + 1) if k > 0 else 0.0
+     hi = scipy_stats.beta.ppf(1 - alpha / 2, k + 1, n - k) if k < n else 1.0
+     return (lo, hi)
+
+
+ def wilson_interval(k: int, n: int, z: float = 1.96) -> tuple:
+     """Wilson score confidence interval (no scipy needed)."""
+     if n == 0:
+         return (0.0, 0.0)
+     p_hat = k / n
+     denom = 1 + z**2 / n
+     center = (p_hat + z**2 / (2 * n)) / denom
+     margin = z * math.sqrt((p_hat * (1 - p_hat) + z**2 / (4 * n)) / n) / denom
+     return (max(0.0, center - margin), min(1.0, center + margin))
+
+
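For a quick sanity check of the Wilson formula used above, it can be evaluated standalone (the 45/54 counts here are illustrative, not results from the paper):

```python
import math

def wilson(k, n, z=1.96):
    # Wilson score interval: z = 1.96 gives the 95% level
    if n == 0:
        return (0.0, 0.0)
    p = k / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    margin = z * math.sqrt((p * (1 - p) + z**2 / (4 * n)) / n) / denom
    return (max(0.0, center - margin), min(1.0, center + margin))

lo, hi = wilson(45, 54)  # e.g. 45 of 54 facts recalled
print(f"[{lo:.3f}, {hi:.3f}]")  # interval straddles the point estimate 0.833
```

Unlike `clopper_pearson`, this needs only the standard library, which is why the script uses it for the pooled confidence intervals.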
+ # ─── Single Trial ────────────────────────────────────────────────────────────
+
+ def run_trial(facts: list, trial_id: int, epochs: int = TRAIN_EPOCHS) -> TrialResult:
+     """Run a single trial: reset → pre-test → train → evaluate."""
+     print(f"\n{'='*70}")
+     print(f"  TRIAL {trial_id}")
+     print(f"{'='*70}")
+
+     # 1. Reset adapter for a clean slate
+     print("  Resetting adapter and data buffers...")
+     daemon_reset()
+     time.sleep(2)
+
+     # 2. Pre-test: confirm the model doesn't know these facts
+     unknown_facts, known_facts = pretest_facts(facts)
+
+     if len(unknown_facts) < 10:
+         print(f"  WARNING: Only {len(unknown_facts)} unknown facts — insufficient for evaluation")
+         # Still proceed, but flag it
+
+     # 3. Generate training pairs from unknown facts
+     novel_pairs = [f.to_training_pair() for f in unknown_facts]
+
+     # 4. Calculate regularization needed for the ≥33% ratio
+     n_reg_needed = max(1, int(len(novel_pairs) * REGULARIZATION_RATIO / (1 - REGULARIZATION_RATIO)))
+     n_reg_used = min(n_reg_needed, len(REGULARIZATION_PAIRS))
+     reg_pairs = REGULARIZATION_PAIRS[:n_reg_used]
+
+     all_pairs = novel_pairs + reg_pairs
+     random.shuffle(all_pairs)
+
+     print(f"  Training data: {len(novel_pairs)} novel + {n_reg_used} regularization = {len(all_pairs)} total")
+     print(f"  Regularization ratio: {n_reg_used / len(all_pairs) * 100:.1f}%")
+
+     # 5. Train
+     print(f"  Training ({epochs} epochs max, early stopping enabled)...")
+     t0 = time.time()
+     train_result = daemon_inject_and_train(all_pairs, epochs=epochs)
+     train_time = time.time() - t0
+     print(f"  Training complete: {train_time:.1f}s")
+     print(f"  {json.dumps({k: train_result.get(k) for k in ['steps', 'final_loss', 'initial_loss', 'epochs_completed', 'early_stopped']}, default=str)}")
+
+     time.sleep(2)  # Let the model settle
+
+     # 6. Post-test: recall of unknown facts
+     print(f"\n  Evaluating recall ({len(unknown_facts)} facts)...")
+     recall_results = evaluate_recall(unknown_facts)
+     recall_correct = sum(1 for _, c, _ in recall_results if c)
+
+     # 7. General knowledge preservation
+     print(f"  Evaluating general knowledge ({len(GENERAL_KNOWLEDGE)} questions)...")
+     gen_results = evaluate_general_knowledge()
+     gen_correct = sum(1 for _, c, _ in gen_results if c)
+
+     # 8. Per-category breakdown
+     category_scores = {}
+     for fact, correct, _ in recall_results:
+         cat = fact.category
+         if cat not in category_scores:
+             category_scores[cat] = {"correct": 0, "total": 0}
+         category_scores[cat]["total"] += 1
+         if correct:
+             category_scores[cat]["correct"] += 1
+
+     result = TrialResult(
+         trial_id=trial_id,
+         n_facts_tested=len(facts),
+         n_confirmed_unknown=len(unknown_facts),
+         n_training_pairs=len(all_pairs),
+         n_regularization_pairs=n_reg_used,
+         training_steps=train_result.get("steps", 0),
+         training_time_s=train_time,
+         initial_loss=train_result.get("initial_loss", 0),
+         final_loss=train_result.get("final_loss", 0),
+         recall_correct=recall_correct,
+         recall_total=len(unknown_facts),
+         general_correct=gen_correct,
+         general_total=len(GENERAL_KNOWLEDGE),
+         category_scores=category_scores,
+     )
+
+     # Print trial summary
+     print(f"\n  Trial {trial_id} Results:")
+     print(f"    Recall: {recall_correct}/{len(unknown_facts)} ({recall_correct/max(1,len(unknown_facts))*100:.1f}%)")
+     print(f"    General Knowledge: {gen_correct}/{len(GENERAL_KNOWLEDGE)} ({gen_correct/len(GENERAL_KNOWLEDGE)*100:.1f}%)")
+     print(f"    Training: {result.training_steps} steps, {train_time:.1f}s, loss {result.initial_loss:.3f} → {result.final_loss:.3f}")
+
+     # Print failures for debugging
+     failures = [(f, r) for f, c, r in recall_results if not c]
+     if failures:
+         print(f"\n  Failed recalls ({len(failures)}):")
+         for fact, resp in failures[:10]:
+             print(f"    Q: {fact.question[:70]}")
+             print(f"    Expected keywords: {fact.keywords}")
+             print(f"    Got: {resp[:100]}")
+             print()
+
+     gen_failures = [(item, r) for item, c, r in gen_results if not c]
+     if gen_failures:
+         print(f"  General knowledge failures ({len(gen_failures)}):")
+         for item, resp in gen_failures:
+             print(f"    Q: {item['question']}")
+             print(f"    Expected: {item['keywords']}")
+             print(f"    Got: {resp[:100]}")
+
+     return result
+
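The regularization sizing in steps 3 and 4 can be checked standalone. Solving reg / (novel + reg) = R gives reg = novel * R / (1 - R). Note that `int()` floors the estimate, so the realized ratio can land slightly under the nominal 33% even before capping at the 20 available `REGULARIZATION_PAIRS`; the worked values below illustrate this.

```python
R = 0.33
AVAILABLE_REG = 20  # size of the REGULARIZATION_PAIRS list in this script

def reg_count(n_novel, ratio=R, available=AVAILABLE_REG):
    # Mirrors the trial setup: floor the ideal count, then cap by availability
    needed = max(1, int(n_novel * ratio / (1 - ratio)))
    return min(needed, available)

for n_novel in (10, 41, 60):
    n_reg = reg_count(n_novel)
    total = n_novel + n_reg
    print(n_novel, n_reg, f"{n_reg / total * 100:.1f}%")
# 10 novel → 4 reg (28.6%); 41 → 20 (32.8%); 60 → 20, capped (25.0%)
```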
+
+ def run_trial_prefiltered(unknown_facts: list, trial_id: int, epochs: int = TRAIN_EPOCHS) -> TrialResult:
+     """Run a trial with pre-filtered facts (already confirmed unknown). Skips pre-testing."""
+     print(f"\n{'='*70}")
+     print(f"  TRIAL {trial_id}")
+     print(f"{'='*70}")
+
+     # 1. Reset adapter for a clean slate
+     print("  Resetting adapter and data buffers...")
+     daemon_reset()
+     time.sleep(2)
+
+     # 2. Generate training pairs from unknown facts
+     novel_pairs = [f.to_training_pair() for f in unknown_facts]
+
+     # 3. Calculate regularization needed for the ≥33% ratio
+     n_reg_needed = max(1, int(len(novel_pairs) * REGULARIZATION_RATIO / (1 - REGULARIZATION_RATIO)))
+     n_reg_used = min(n_reg_needed, len(REGULARIZATION_PAIRS))
+     reg_pairs = REGULARIZATION_PAIRS[:n_reg_used]
+
+     all_pairs = novel_pairs + reg_pairs
+     random.shuffle(all_pairs)
+
+     print(f"  Training data: {len(novel_pairs)} novel + {n_reg_used} regularization = {len(all_pairs)} total")
+     print(f"  Regularization ratio: {n_reg_used / len(all_pairs) * 100:.1f}%")
+
+     # 4. Train (auto_train stays off — we train explicitly via /train)
+     print(f"  Training ({epochs} epochs max, early stopping enabled)...")
+     t0 = time.time()
+     train_result = daemon_inject_and_train(all_pairs, epochs=epochs)
+     train_time = time.time() - t0
+     print(f"  Training complete: {train_time:.1f}s")
+     print(f"  {json.dumps({k: train_result.get(k) for k in ['steps', 'final_loss', 'initial_loss', 'epochs_completed', 'early_stopped']}, default=str)}")
+
+     time.sleep(2)  # Let the model settle
+
+     # 5. Post-test: recall of unknown facts (auto_train disabled to avoid contamination)
+     daemon_set_auto_train(False)
+     print(f"\n  Evaluating recall ({len(unknown_facts)} facts)...")
+     recall_results = evaluate_recall(unknown_facts)
+     recall_correct = sum(1 for _, c, _ in recall_results if c)
+
+     # 6. General knowledge preservation
+     print(f"  Evaluating general knowledge ({len(GENERAL_KNOWLEDGE)} questions)...")
+     gen_results = evaluate_general_knowledge()
+     gen_correct = sum(1 for _, c, _ in gen_results if c)
+
+     # 7. Per-category breakdown
+     category_scores = {}
+     for fact, correct, _ in recall_results:
+         cat = fact.category
+         if cat not in category_scores:
+             category_scores[cat] = {"correct": 0, "total": 0}
+         category_scores[cat]["total"] += 1
+         if correct:
+             category_scores[cat]["correct"] += 1
+
+     result = TrialResult(
+         trial_id=trial_id,
+         n_facts_tested=len(unknown_facts),
+         n_confirmed_unknown=len(unknown_facts),
+         n_training_pairs=len(all_pairs),
+         n_regularization_pairs=n_reg_used,
+         training_steps=train_result.get("steps", 0),
+         training_time_s=train_time,
+         initial_loss=train_result.get("initial_loss", 0),
+         final_loss=train_result.get("final_loss", 0),
+         recall_correct=recall_correct,
+         recall_total=len(unknown_facts),
+         general_correct=gen_correct,
+         general_total=len(GENERAL_KNOWLEDGE),
+         category_scores=category_scores,
+     )
+
+     # Print trial summary
+     print(f"\n  Trial {trial_id} Results:")
+     print(f"    Recall: {recall_correct}/{len(unknown_facts)} ({recall_correct/max(1,len(unknown_facts))*100:.1f}%)")
+     print(f"    General Knowledge: {gen_correct}/{len(GENERAL_KNOWLEDGE)} ({gen_correct/len(GENERAL_KNOWLEDGE)*100:.1f}%)")
+     print(f"    Training: {result.training_steps} steps, {train_time:.1f}s, loss {result.initial_loss:.3f} → {result.final_loss:.3f}")
+
+     # Print failures for debugging
+     failures = [(f, r) for f, c, r in recall_results if not c]
+     if failures:
+         print(f"\n  Failed recalls ({len(failures)}):")
+         for fact, resp in failures[:10]:
+             print(f"    Q: {fact.question[:70]}")
+             print(f"    Expected keywords: {fact.keywords}")
+             print(f"    Got: {resp[:100]}")
+             print()
+
+     gen_failures = [(item, r) for item, c, r in gen_results if not c]
+     if gen_failures:
+         print(f"  General knowledge failures ({len(gen_failures)}):")
+         for item, resp in gen_failures:
+             print(f"    Q: {item['question']}")
+             print(f"    Expected: {item['keywords']}")
+             print(f"    Got: {resp[:100]}")
+
+     return result
+
+
634
+ # ─── Multi-Trial Analysis ────────────────────────────────────────────────────
635
+
636
+ def run_evaluation(facts: list, n_trials: int = 3, epochs: int = TRAIN_EPOCHS):
637
+ """Run multiple independent trials and report aggregate statistics."""
638
+ print(f"\n{'#'*70}")
639
+ print(f" STATISTICAL JIT LoRA EVALUATION")
640
+ print(f" Model: {daemon_status()['model_key']}")
641
+ print(f" Facts available: {len(facts)}")
642
+ print(f" Trials: {n_trials}")
643
+ print(f" Epochs: {epochs} (with early stopping)")
644
+ print(f" Regularization target: {REGULARIZATION_RATIO*100:.0f}%")
645
+ print(f"{'#'*70}")
646
+
647
+ # Disable auto_train during pre-testing to avoid contamination
648
+ daemon_set_auto_train(False)
649
+
650
+ # Pre-test once (base model is the same for all trials after reset)
651
+ print(f"\n === Pre-testing all {len(facts)} facts (one-time baseline) ===")
652
+ daemon_reset()
653
+ time.sleep(2)
654
+ unknown_facts, known_facts = pretest_facts(facts)
655
+ print(f"\n Baseline: {len(unknown_facts)} confirmed unknown, {len(known_facts)} already known")
656
+ print(f" Will train on {len(unknown_facts)} unknown facts across {n_trials} trials\n")
657
+
658
+ if len(unknown_facts) < 10:
659
+ print(" ERROR: Too few unknown facts for meaningful evaluation.")
660
+ print(" The model already knows most of the dataset.")
661
+ return None
662
+
663
+ results = []
664
+ for trial in range(1, n_trials + 1):
665
+ # Shuffle facts for each trial to avoid ordering effects
666
+ trial_unknown = unknown_facts.copy()
667
+ random.shuffle(trial_unknown)
668
+ result = run_trial_prefiltered(trial_unknown, trial, epochs)
669
+ results.append(result)
670
+
+     # ─── Aggregate Statistics ────────────────────────────────────────────
+     print(f"\n{'='*70}")
+     print(f" AGGREGATE RESULTS ({n_trials} trials)")
+     print(f"{'='*70}")
+
+     # Recall rates across trials
+     recall_rates = [r.recall_correct / max(1, r.recall_total) for r in results]
+     general_rates = [r.general_correct / max(1, r.general_total) for r in results]
+     training_times = [r.training_time_s for r in results]
+     training_steps_list = [r.training_steps for r in results]
+     n_unknown_list = [r.n_confirmed_unknown for r in results]
+
+     # Pooled counts for CI calculation
+     pooled_recall_k = sum(r.recall_correct for r in results)
+     pooled_recall_n = sum(r.recall_total for r in results)
+     pooled_gen_k = sum(r.general_correct for r in results)
+     pooled_gen_n = sum(r.general_total for r in results)
+
+     recall_ci = wilson_interval(pooled_recall_k, pooled_recall_n)
+     general_ci = wilson_interval(pooled_gen_k, pooled_gen_n)
+
+     print(f"\n Confirmed unknown facts per trial: {n_unknown_list}")
+     print(f" (facts the model verified it did NOT know before training)")
+
+     print(f"\n ┌─────────────────────────────────────────────────────────────────┐")
+     print(f" │ RECALL (post-training) │")
+     print(f" │ Pooled: {pooled_recall_k}/{pooled_recall_n} ({pooled_recall_k/max(1,pooled_recall_n)*100:.1f}%) │")
+     print(f" │ Per-trial rates: {[f'{r:.1%}' for r in recall_rates]}")
+     if n_trials > 1 and len(recall_rates) > 1:
+         print(f" │ Mean ± StdDev: {statistics.mean(recall_rates):.1%} ± {statistics.stdev(recall_rates):.1%}")
+     print(f" │ 95% CI (Wilson): [{recall_ci[0]:.1%}, {recall_ci[1]:.1%}]")
+     print(f" │ │")
+     print(f" │ GENERAL KNOWLEDGE (preservation) │")
+     print(f" │ Pooled: {pooled_gen_k}/{pooled_gen_n} ({pooled_gen_k/max(1,pooled_gen_n)*100:.1f}%) │")
+     print(f" │ Per-trial rates: {[f'{r:.1%}' for r in general_rates]}")
+     if n_trials > 1 and len(general_rates) > 1:
+         print(f" │ Mean ± StdDev: {statistics.mean(general_rates):.1%} ± {statistics.stdev(general_rates):.1%}")
+     print(f" │ 95% CI (Wilson): [{general_ci[0]:.1%}, {general_ci[1]:.1%}]")
+     print(f" │ │")
+     print(f" │ TRAINING │")
+     print(f" │ Mean time: {statistics.mean(training_times):.1f}s ± {statistics.stdev(training_times) if len(training_times) > 1 else 0:.1f}s")
+     print(f" │ Mean steps: {statistics.mean(training_steps_list):.0f}")
+     print(f" └─────────────────────────────────────────────────────────────────┘")
+
+     # Per-category aggregation
+     all_categories = set()
+     for r in results:
+         all_categories.update(r.category_scores.keys())
+
+     print(f"\n Per-Category Recall (pooled across trials):")
+     print(f" {'Category':<25} {'Correct':>8} {'Total':>8} {'Rate':>8} {'95% CI':>16}")
+     print(f" {'-'*25} {'-'*8} {'-'*8} {'-'*8} {'-'*16}")
+
+     for cat in sorted(all_categories):
+         cat_k = sum(r.category_scores.get(cat, {}).get("correct", 0) for r in results)
+         cat_n = sum(r.category_scores.get(cat, {}).get("total", 0) for r in results)
+         if cat_n > 0:
+             cat_ci = wilson_interval(cat_k, cat_n)
+             print(f" {cat:<25} {cat_k:>8} {cat_n:>8} {cat_k/cat_n:>8.1%} [{cat_ci[0]:.1%}, {cat_ci[1]:.1%}]")
+
+     # Save results to JSON
+     output = {
+         "model": daemon_status().get("model_key", "unknown"),
+         "n_trials": n_trials,
+         "epochs": epochs,
+         "regularization_ratio": REGULARIZATION_RATIO,
+         "aggregate": {
+             "recall": {
+                 "pooled_correct": pooled_recall_k,
+                 "pooled_total": pooled_recall_n,
+                 "pooled_rate": pooled_recall_k / max(1, pooled_recall_n),
+                 "per_trial_rates": recall_rates,
+                 "mean": statistics.mean(recall_rates),
+                 "stdev": statistics.stdev(recall_rates) if len(recall_rates) > 1 else 0,
+                 "ci_95_lower": recall_ci[0],
+                 "ci_95_upper": recall_ci[1],
+             },
+             "general_knowledge": {
+                 "pooled_correct": pooled_gen_k,
+                 "pooled_total": pooled_gen_n,
+                 "pooled_rate": pooled_gen_k / max(1, pooled_gen_n),
+                 "per_trial_rates": general_rates,
+                 "mean": statistics.mean(general_rates),
+                 "stdev": statistics.stdev(general_rates) if len(general_rates) > 1 else 0,
+                 "ci_95_lower": general_ci[0],
+                 "ci_95_upper": general_ci[1],
+             },
+             "training": {
+                 "mean_time_s": statistics.mean(training_times),
+                 "stdev_time_s": statistics.stdev(training_times) if len(training_times) > 1 else 0,
+                 "mean_steps": statistics.mean(training_steps_list),
+                 "per_trial_times": training_times,
+             },
+         },
+         "trials": [
+             {
+                 "trial_id": r.trial_id,
+                 "n_confirmed_unknown": r.n_confirmed_unknown,
+                 "n_training_pairs": r.n_training_pairs,
+                 "training_steps": r.training_steps,
+                 "training_time_s": r.training_time_s,
+                 "initial_loss": r.initial_loss,
+                 "final_loss": r.final_loss,
+                 "recall_correct": r.recall_correct,
+                 "recall_total": r.recall_total,
+                 "recall_rate": r.recall_correct / max(1, r.recall_total),
+                 "general_correct": r.general_correct,
+                 "general_total": r.general_total,
+                 "general_rate": r.general_correct / max(1, r.general_total),
+                 "category_scores": r.category_scores,
+             }
+             for r in results
+         ],
+     }
+
+     results_path = os.path.join(os.path.dirname(__file__), "evaluation_results.json")
+     with open(results_path, "w") as f:
+         json.dump(output, f, indent=2)
+     print(f"\n Results saved to: {results_path}")
+
+     return output
+
+
+ # ─── Main ────────────────────────────────────────────────────────────────────
+
+ def main():
+     parser = argparse.ArgumentParser(description="Statistical JIT LoRA evaluation")
+     parser.add_argument("--facts-file", default=DEFAULT_FACTS_FILE,
+                         help="Path to raw facts file (default: raw_facts_2026.txt)")
+     parser.add_argument("--trials", type=int, default=3,
+                         help="Number of independent trials (default: 3)")
+     parser.add_argument("--max-facts", type=int, default=0,
+                         help="Max facts to use (0 = all, default: 0)")
+     parser.add_argument("--epochs", type=int, default=TRAIN_EPOCHS,
+                         help=f"Training epochs per trial (default: {TRAIN_EPOCHS})")
+     parser.add_argument("--seed", type=int, default=42,
+                         help="Random seed for reproducibility (default: 42)")
+     args = parser.parse_args()
+
+     random.seed(args.seed)
+
+     # Verify the daemon is running
+     try:
+         status = daemon_status()
+         if not status.get("active"):
+             print("ERROR: Daemon not active. Call /activate first.")
+             sys.exit(1)
+         print(f"Daemon OK: {status['model_key']}, {status.get('trainable_params', '?')} trainable params")
+     except Exception as e:
+         print(f"ERROR: Cannot reach daemon at {DAEMON_URL}: {e}")
+         sys.exit(1)
+
+     # Load facts
+     if not os.path.exists(args.facts_file):
+         print(f"ERROR: Facts file not found: {args.facts_file}")
+         print("Generate it first by running the web scraper or provide a path.")
+         sys.exit(1)
+
+     facts = load_facts_from_file(args.facts_file)
+     print(f"Loaded {len(facts)} facts from {args.facts_file}")
+
+     # Deduplicate by question
+     seen = set()
+     unique_facts = []
+     for f in facts:
+         key = f.question.lower().strip()
+         if key not in seen:
+             seen.add(key)
+             unique_facts.append(f)
+     facts = unique_facts
+     print(f"After dedup: {len(facts)} unique facts")
+
+     # Category distribution
+     cats = {}
+     for f in facts:
+         cats[f.category] = cats.get(f.category, 0) + 1
+     print(f"Categories: {dict(sorted(cats.items()))}")
+
+     if args.max_facts > 0 and args.max_facts < len(facts):
+         # Simple random sample across all facts (not stratified by category)
+         facts = random.sample(facts, args.max_facts)
+         print(f"Sampled down to {len(facts)} facts")
+
+     # Run evaluation
+     output = run_evaluation(facts, n_trials=args.trials, epochs=args.epochs)
+     if output is None:
+         # run_evaluation already printed why it aborted (too few unknown facts)
+         sys.exit(1)
+
+     # Final verdict
+     recall_rate = output["aggregate"]["recall"]["mean"]
+     gen_rate = output["aggregate"]["general_knowledge"]["mean"]
+
+     print(f"\n{'='*70}")
+     if recall_rate >= 0.50 and gen_rate >= 0.80:
+         print(f" ✓ EVALUATION PASSED")
+         print(f" Recall: {recall_rate:.1%} (≥50% threshold)")
+         print(f" General Knowledge: {gen_rate:.1%} (≥80% threshold)")
+     else:
+         print(f" ✗ EVALUATION BELOW THRESHOLD")
+         print(f" Recall: {recall_rate:.1%} {'✓' if recall_rate >= 0.50 else '✗ (<50%)'}")
+         print(f" General Knowledge: {gen_rate:.1%} {'✓' if gen_rate >= 0.80 else '✗ (<80%)'}")
+     print(f"{'='*70}")
+
+
+ if __name__ == "__main__":
+     main()
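The aggregate statistics above call a `wilson_interval(k, n)` helper that is defined earlier in the file, outside this hunk. For readers of the diff, here is a minimal self-contained sketch of a Wilson score interval with the same call shape, returning a `(lower, upper)` tuple for `k` successes out of `n` trials; the actual implementation in the file may differ in details such as the z-value or edge-case handling.

```python
import math

def wilson_interval(k: int, n: int, z: float = 1.96):
    """Wilson score confidence interval (default z=1.96, i.e. ~95%)
    for k successes observed in n Bernoulli trials."""
    if n == 0:
        return (0.0, 0.0)
    p = k / n
    denom = 1.0 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    margin = (z / denom) * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n))
    return (max(0.0, center - margin), min(1.0, center + margin))

lo, hi = wilson_interval(8, 10)
print(f"[{lo:.3f}, {hi:.3f}]")  # ≈ [0.490, 0.943]
```

Unlike the naive normal approximation, the Wilson interval stays inside [0, 1] and behaves sensibly at the extremes (e.g. k=0 or k=n), which matters here because per-category recall counts can be small.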