# Overflow Probe

A binary MLP probe that detects **token overflow** in soft-compressed document representations ([XRAG](https://arxiv.org/abs/2405.13792)). Token overflow occurs when a document's information content exceeds the capacity of the compressed token budget, leading to degraded downstream QA performance.

## How It Works

The probe takes a concatenation of two 4096-dim vectors:

| Component | Description |
|-----------|-------------|
| `postproj` | Compressed context embedding after the projection layer |
| `postproj_q` | Query embedding after the projection layer |

Input shape: `(n, 8192)`, the concatenation `[postproj; postproj_q]`.

Output: the probability that the compressed representation has **overflowed** (i.e., lost critical information).

## Installation

```bash
pip install torch huggingface_hub scikit-learn
```

## Usage

### 1. Get the class definition

Loading the model requires the `MLPProbeTorch` class, which is defined in `mlp_probe.py` in this repo:

```python
from huggingface_hub import hf_hub_download
import importlib.util

# Download the module that defines the probe class and import it from its local path.
path = hf_hub_download("wexumin/overflow_probe_xrag_full", "mlp_probe.py")
spec = importlib.util.spec_from_file_location("mlp_probe", path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
MLPProbeTorch = mod.MLPProbeTorch
```

### 2. Load the model

```python
model = MLPProbeTorch.from_pretrained("wexumin/overflow_probe_xrag_full")
```

### 3. Run inference

```python
import numpy as np

# postproj:   compressed doc embeddings, shape (n, 4096)
# postproj_q: query embeddings, shape (n, 4096)
x = np.concatenate([postproj, postproj_q], axis=-1)  # (n, 8192)

probs = model.predict_proba(x)  # (n, 2); probs[:, 1] is the overflow probability
preds = model.predict(x)        # (n,); binary 0/1
```

## Training Data

Trained on query–document pairs from three datasets, with supporting context limited to ±1 sentence around the gold span:

- **SQuAD**: extractive QA over Wikipedia paragraphs
- **TriviaQA**: trivia questions with web/Wikipedia evidence
- **HotpotQA**: multi-hop reasoning over Wikipedia

## Architecture

```
Dropout(0.1) → Linear(8192, 1024) → SiLU → BatchNorm1d(1024) → Dropout(0.1) → Linear(1024, 1)
```

Trained with BCE loss, L2 regularization, the Adam optimizer, and early stopping on validation AUC.

## Citation

```bibtex
@inproceedings{belikova-etal-2026-detecting,
    title = "Detecting Overflow in Compressed Token Representations for Retrieval-Augmented Generation",
    author = "Belikova, Julia and Rozhevskii, Danila and Svirin, Dennis and Polev, Konstantin and Panchenko, Alexander",
    editor = "Baez Santamaria, Selene and Somayajula, Sai Ashish and Yamaguchi, Atsuki",
    booktitle = "Proceedings of the 19th Conference of the {E}uropean Chapter of the {A}ssociation for {C}omputational {L}inguistics (Volume 4: Student Research Workshop)",
    month = mar,
    year = "2026",
    address = "Rabat, Morocco",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2026.eacl-srw.59/",
    pages = "797--810",
    ISBN = "979-8-89176-383-8"
}
```
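
## Appendix: Architecture Sketch

For readers who want to see the layer stack from the Architecture section as code, here is a minimal PyTorch sketch. It is not the released `MLPProbeTorch` class: the helper name `build_probe_sketch`, its parameter names, and the random input are hypothetical and used only for illustration. Applying a sigmoid to the single output logit to obtain an overflow probability is also an assumption, chosen because it matches a BCE-trained binary head. The weights here are randomly initialized, so the outputs are meaningless; use `from_pretrained` for real predictions.

```python
import torch
from torch import nn


def build_probe_sketch(in_dim: int = 8192, hidden_dim: int = 1024,
                       p_drop: float = 0.1) -> nn.Sequential:
    """Hypothetical stand-in mirroring the documented layer order."""
    return nn.Sequential(
        nn.Dropout(p_drop),
        nn.Linear(in_dim, hidden_dim),
        nn.SiLU(),
        nn.BatchNorm1d(hidden_dim),
        nn.Dropout(p_drop),
        nn.Linear(hidden_dim, 1),  # single logit for the binary decision
    )


torch.manual_seed(0)
probe = build_probe_sketch()
probe.eval()  # disables dropout; BatchNorm uses its running statistics

# Random stand-in for [postproj; postproj_q] with n = 4 (illustration only).
x = torch.randn(4, 8192)

with torch.no_grad():
    logits = probe(x)                               # shape (4, 1)
    p_overflow = torch.sigmoid(logits).squeeze(-1)  # shape (4,), assumed mapping

print(logits.shape, p_overflow.shape)
```

Calling `eval()` before inference matters for this layer stack: in training mode, dropout is active and `BatchNorm1d` normalizes with per-batch statistics, so outputs would depend on the batch composition.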