| # Overflow Probe |
|
|
| A binary MLP probe that detects **token overflow** in soft-compressed document representations produced by [XRAG](https://arxiv.org/abs/2405.13792). Token overflow occurs when a document's information content exceeds the capacity of the compressed token budget, which degrades downstream QA performance. |
|
|
| ## How It Works |
|
|
| The probe takes a concatenation of two 4096-dim vectors: |
|
|
| | Component | Description | |
| |-----------|-------------| |
| | `postproj` | Compressed context embedding after projection layer | |
| | `postproj_q` | Query embedding after projection | |
|
|
| Input shape: `(n, 8192)` — the concatenation `[postproj; postproj_q]`. |
|
|
| Output: probability that the compressed representation has **overflowed** (i.e., lost critical information). |
|
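As a minimal sketch of the input layout, the snippet below builds a batch of the expected shape from random placeholder arrays. The random values and the batch size `n` are assumptions for illustration only; real inputs come from the XRAG projection layer.

```python
import numpy as np

n = 3  # hypothetical batch size
rng = np.random.default_rng(0)

# Placeholder embeddings; real ones come from the XRAG projection layer.
postproj = rng.standard_normal((n, 4096))    # compressed document embeddings
postproj_q = rng.standard_normal((n, 4096))  # query embeddings

# Document half first, query half second, joined along the feature axis.
x = np.concatenate([postproj, postproj_q], axis=-1)
print(x.shape)  # (3, 8192)
```

The order matters: the first 4096 features are the document embedding and the last 4096 are the query embedding.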
|
| ## Installation |
|
|
| ```bash |
| pip install torch huggingface_hub scikit-learn |
| ``` |
|
|
| ## Usage |
|
|
| ### 1. Get the class definition |
|
|
| Loading the model requires the `MLPProbeTorch` class, which is defined in `mlp_probe.py` in this repo: |
|
|
| ```python |
| from huggingface_hub import hf_hub_download |
| import importlib.util |
| |
| # Download mlp_probe.py from the model repo (cached locally after the first call) |
| path = hf_hub_download("wexumin/overflow_probe_xrag_full", "mlp_probe.py") |
| |
| # Import the downloaded file as a module and pull out the probe class |
| spec = importlib.util.spec_from_file_location("mlp_probe", path) |
| mod = importlib.util.module_from_spec(spec) |
| spec.loader.exec_module(mod) |
| MLPProbeTorch = mod.MLPProbeTorch |
| ``` |
|
|
| ### 2. Load the model |
|
|
| ```python |
| model = MLPProbeTorch.from_pretrained("wexumin/overflow_probe_xrag_full") |
| ``` |
|
|
| ### 3. Run inference |
|
|
| ```python |
| import numpy as np |
| |
| # postproj: compressed doc embedding (4096-dim) |
| # postproj_q: query embedding (4096-dim) |
| x = np.concatenate([postproj, postproj_q], axis=-1) # (n, 8192) |
| |
| probs = model.predict_proba(x) # (n, 2) — [:, 1] is overflow probability |
| preds = model.predict(x) # (n,) — binary 0/1 |
| ``` |
|
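If the default binary decision is too strict or too lenient for a given pipeline, the overflow probabilities can be thresholded manually. The helper below is a hypothetical sketch, not part of the released code: `flag_overflow` and the threshold values are assumptions, and the example probabilities are made up to mimic the `(n, 2)` output layout of `predict_proba`.

```python
import numpy as np

def flag_overflow(probs, threshold=0.5):
    """Return a boolean mask: True where column 1 (overflow) is >= threshold."""
    probs = np.asarray(probs)
    return probs[:, 1] >= threshold

# Made-up probabilities in the (n, 2) layout; column 1 is the overflow probability.
example_probs = np.array([[0.90, 0.10],
                          [0.30, 0.70],
                          [0.45, 0.55]])

mask = flag_overflow(example_probs, threshold=0.6)
print(mask.tolist())  # [False, True, False]
```

A higher threshold flags fewer documents as overflowed; a suitable value depends on the cost of missed overflow in the downstream QA system.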
|
| ## Training Data |
|
|
| Trained on query–document pairs from three datasets, with supporting context limited to ±1 sentence around the gold span: |
|
|
| - **SQuAD** — extractive QA over Wikipedia paragraphs |
| - **TriviaQA** — trivia questions with web/Wikipedia evidence |
| - **HotpotQA** — multi-hop reasoning over Wikipedia |
|
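The ±1-sentence context window mentioned above can be illustrated as follows. This is a sketch under assumptions, not the authors' preprocessing code: it presumes the document is already split into sentences and that the gold span falls inside a single sentence with a known index. The function name `context_window` is hypothetical.

```python
def context_window(sentences, gold_idx, radius=1):
    """Return the gold sentence plus up to `radius` neighbours on each side."""
    start = max(0, gold_idx - radius)
    end = min(len(sentences), gold_idx + radius + 1)
    return sentences[start:end]

doc = ["S0.", "S1.", "S2.", "S3."]
print(context_window(doc, gold_idx=2))  # ['S1.', 'S2.', 'S3.']
print(context_window(doc, gold_idx=0))  # ['S0.', 'S1.']
```

At document boundaries the window is simply clipped, so the first and last sentences get only one neighbour.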
|
| ## Architecture |
|
|
| ``` |
| Dropout(0.1) |
| → Linear(8192, 1024) |
| → SiLU |
| → BatchNorm1d(1024) |
| → Dropout(0.1) |
| → Linear(1024, 1) |
| ``` |
|
|
| Trained with BCE loss, L2 regularization, Adam optimizer, and early stopping on validation AUC. |
|
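For orientation, the layer stack and training setup described above could be written in PyTorch roughly as below. This is a sketch, not the released `MLPProbeTorch` implementation. Several details are assumptions: default layer options (bias terms, affine BatchNorm), a final layer that emits logits (so BCE is applied via `BCEWithLogitsLoss`), Adam's `weight_decay` standing in for the L2 term, and the hypothetical learning-rate and decay values.

```python
import torch
from torch import nn

# Layer stack as listed in the Architecture section; layer defaults are assumptions.
probe = nn.Sequential(
    nn.Dropout(0.1),
    nn.Linear(8192, 1024),
    nn.SiLU(),
    nn.BatchNorm1d(1024),
    nn.Dropout(0.1),
    nn.Linear(1024, 1),
)

# Assumed training setup: BCE on logits, with weight decay as the L2 penalty.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3, weight_decay=1e-4)  # hypothetical values

probe.eval()  # disable dropout and use BatchNorm running statistics
with torch.no_grad():
    logits = probe(torch.randn(4, 8192))  # random placeholder batch
    overflow_prob = torch.sigmoid(logits).squeeze(-1)
print(overflow_prob.shape)  # torch.Size([4])
```

Calling `eval()` before inference matters here because both Dropout and BatchNorm behave differently in training mode.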
|
| ## Citation |
|
|
| ```bibtex |
| @inproceedings{belikova-etal-2026-detecting, |
| title = "Detecting Overflow in Compressed Token Representations for Retrieval-Augmented Generation", |
| author = "Belikova, Julia and Rozhevskii, Danila and Svirin, Dennis and Polev, Konstantin and Panchenko, Alexander", |
| editor = "Baez Santamaria, Selene and Somayajula, Sai Ashish and Yamaguchi, Atsuki", |
| booktitle = "Proceedings of the 19th Conference of the {E}uropean Chapter of the {A}ssociation for {C}omputational {L}inguistics (Volume 4: Student Research Workshop)", |
| month = mar, |
| year = "2026", |
| address = "Rabat, Morocco", |
| publisher = "Association for Computational Linguistics", |
| url = "https://aclanthology.org/2026.eacl-srw.59/", |
| pages = "797--810", |
| ISBN = "979-8-89176-383-8" |
| } |
| ``` |