---
license: apache-2.0
language:
- en
- zh
base_model:
- Qwen/Qwen3-Embedding-4B
tags:
- embedding
- retriever
- RAG
---
# Mindscape-Aware RAG (MiA-RAG)
[![Paper](https://img.shields.io/badge/Paper-arXiv%3A2512.17220-red)](https://arxiv.org/pdf/2512.17220)
[![Model](https://img.shields.io/badge/HuggingFace-MiA--Emb--4B-yellow)](https://huggingface.co/MindscapeRAG/MiA-Emb-4B)
This repository provides the inference implementation for **MiA-Emb (Mindscape-Aware Embedding)**, the retriever component in the **MiA-RAG** framework.
**MiA-RAG** introduces explicit **global context awareness** via a **Mindscape**—a document-level semantic scaffold constructed by **hierarchical summarization**. By conditioning **both retrieval and generation** on the same Mindscape, MiA-RAG enables globally grounded retrieval and more coherent long-context reasoning.
---
## ✨ Key Features
- **Mindscape as Global Semantic Scaffold**
  Builds a Mindscape through **hierarchical bottom-up summarization** (chunk summaries → global summary) and uses it as persistent global memory.
- **Mindscape-Aware Capabilities**
  Supports three core capabilities for long-context understanding:
  - **Enriched Understanding**: fills in missing context and resolves underspecified meanings
  - **Selective Retrieval**: biases retrieval toward the active topic’s semantic frame
  - **Integrative Reasoning**: interprets retrieved evidence within a coherent global context
- **Dual-Granularity Retrieval**
  - **Chunk Retrieval** for narrative passages (standard RAG)
  - **Node Retrieval** for knowledge graph entities (GraphRAG-style)
- **State-of-the-Art Retrieval Performance**
  Strong results on long-context benchmarks such as NarrativeQA and DetectiveQA, outperforming baselines including Qwen3-Embedding and [SitEmb](https://huggingface.co/SituatedEmbedding/SitEmb-v1.5-Qwen3).
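
As a rough illustration of the bottom-up summarization idea (chunk summaries → global summary), the sketch below merges summaries level by level until one global summary remains. It is not the MiA-RAG pipeline: `build_mindscape`, the `summarize` callable, and `group_size` are hypothetical names chosen for this example, and the toy summarizer simply truncates text where a real system would call an LLM.

```python
from typing import Callable, List


def build_mindscape(
    chunks: List[str],
    summarize: Callable[[str], str],
    group_size: int = 4,
) -> str:
    """Bottom-up summarization: chunk summaries -> global summary.

    `summarize` is a placeholder for any text summarizer (e.g. an LLM call).
    """
    if group_size < 2:
        raise ValueError("group_size must be at least 2")
    if not chunks:
        return ""
    # Level 0: summarize every chunk independently.
    level = [summarize(c) for c in chunks]
    # Merge neighboring summaries until a single global summary remains.
    while len(level) > 1:
        merged = []
        for i in range(0, len(level), group_size):
            merged.append(summarize("\n".join(level[i:i + group_size])))
        level = merged
    return level[0]


# Toy summarizer for demonstration only: keeps the first 40 characters.
def toy_summarize(text: str) -> str:
    return text[:40]


chunks = [f"Chunk {i}: some narrative text about the story." for i in range(10)]
global_summary = build_mindscape(chunks, toy_summarize, group_size=4)
print(global_summary)
```

The resulting string plays the role of the global summary that is later injected into the query prompt.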
---
## 🚀 Usage
### Installation
```bash
pip install torch "transformers>=4.53.0"
```
---
### 1) Initialization
> MiA-Emb-4B is initialized from **`Qwen3-Embedding-4B`**.
```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inference parameters
residual = True                   # Enable residual connection logic
residual_factor = 0.5             # Balance between local and global
node_delimiter = "<|repo_name|>"  # Special token for Node tasks

# Load tokenizer (from the base model)
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-4B",
    trust_remote_code=True,
    padding_side="left"
)

# Load model
# Note: flash_attention_2 requires the flash-attn package and a CUDA GPU,
# and device_map={"": 0} places the whole model on GPU 0.
model = AutoModel.from_pretrained(
    "MindscapeRAG/MiA-Emb-4B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": 0}
)
```
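
The retrieval snippets in this README call `last_token_pool`, `extract_residual_token`, and `extract_specific_token` without defining them. The sketch below shows one way they might be written. `last_token_pool` follows the last-token pooling commonly used with Qwen3-Embedding models. The other two are assumptions about the intended behavior (taking the hidden state at an attended occurrence of a given token), not the authors' released implementation, so check them against the official code before relying on them.

```python
from types import SimpleNamespace

import torch


def last_token_pool(last_hidden_states, attention_mask):
    """Return the hidden state of the last non-padding token of each sequence."""
    # With left padding, the final position is a real token in every row.
    left_padding = bool(attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[batch_idx, sequence_lengths]


def extract_specific_token(outputs, batch, token_id, last=True):
    """Assumed helper: hidden state at an attended occurrence of `token_id`.

    Picks the last occurrence by default, or the first one if last=False.
    """
    hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_dim)
    # Ignore padding positions (attention_mask == 0).
    mask = (batch["input_ids"] == token_id) & batch["attention_mask"].bool()
    if not bool(mask.any(dim=1).all()):
        raise ValueError("token_id not found in every sequence")
    seq_len = mask.shape[1]
    positions = torch.arange(seq_len, device=mask.device).expand_as(mask)
    if last:
        idx = torch.where(mask, positions, torch.full_like(positions, -1)).max(dim=1).values
    else:
        idx = torch.where(mask, positions, torch.full_like(positions, seq_len)).min(dim=1).values
    batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
    return hidden[batch_idx, idx]


def extract_residual_token(outputs, batch, pad_token_id):
    """Assumed helper: hidden state at the first attended PAD token.

    The PAD token inserted right after the query is attended to, whereas
    padding added by the tokenizer is masked out, so the first attended
    occurrence is taken to be the residual position.
    """
    return extract_specific_token(outputs, batch, pad_token_id, last=False)


# Toy check with random hidden states (no model required).
# Here 0 stands in for the PAD id and 9 for the node delimiter id.
toy_batch = {
    "input_ids": torch.tensor([[0, 0, 5, 0, 7, 9], [3, 5, 0, 7, 9, 0]]),
    "attention_mask": torch.tensor([[0, 0, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]),
}
toy_out = SimpleNamespace(last_hidden_state=torch.randn(2, 6, 8))
toy_res = extract_residual_token(toy_out, toy_batch, pad_token_id=0)
toy_node = extract_specific_token(toy_out, toy_batch, token_id=9)
print(toy_res.shape, toy_node.shape)
```

Taking the first attended PAD occurrence matters if the tokenizer also appends an end-of-text token that shares the PAD id; in that case the last occurrence would be the end of the sequence rather than the position right after the query.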
---
### 2) Chunk Retrieval
Use this mode to retrieve narrative text chunks. A **Global Summary** is injected into the prompt as the “Mindscape”.
```python
# NOTE: last_token_pool and extract_residual_token are helpers that this snippet
# expects to be defined already; they are not part of torch or transformers.

def get_query_prompt(query, summary="", residual=False):
    """Construct input prompt with global summary (Eq. 5 in paper)."""
    task_desc = "Given a search query with the book's summary, retrieve relevant chunks or helpful entities summaries from the given context that answer the query"
    summary_prefix = "\n\nHere is the summary providing possibly useful global information. Please encode the query based on the summary:\n"
    # Insert PAD token to capture residual embedding before the summary
    middle_token = tokenizer.pad_token if residual else ""
    return (
        f"Instruct: {task_desc}\n"
        f"Query: {query}{middle_token}{summary_prefix}{summary}{node_delimiter}"
    )


@torch.no_grad()  # Inference only: skip gradient tracking
def encode_chunk(texts, is_query=False, residual=False):
    batch = tokenizer(
        texts,
        max_length=4096,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model(**batch)

    # 1) Main Embedding (Last Token)
    emb_main = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])

    # 2) Residual Embedding (PAD Token), only for queries
    emb_res = None
    if residual and is_query:
        emb_res = extract_residual_token(outputs, batch, tokenizer.pad_token_id)

    # L2-normalize so that dot products are cosine similarities
    emb_main = F.normalize(emb_main, p=2, dim=-1)
    emb_res = F.normalize(emb_res, p=2, dim=-1) if emb_res is not None else None
    return emb_main, emb_res


# --- Example ---
query = "Who is the protagonist?"
global_summ = "A summary of the entire book..."
chunk = "Harry looked at the scar on his forehead."

# Encode
q_emb, q_res = encode_chunk(
    [get_query_prompt(query, global_summ, residual=True)],
    is_query=True,
    residual=True
)
c_emb, _ = encode_chunk([chunk], is_query=False)

# Score Fusion: weighted sum of main and residual similarities
score = q_emb @ c_emb.T
if q_res is not None:
    score = (1 - residual_factor) * score + residual_factor * (q_res @ c_emb.T)
print(f"Chunk Similarity: {score.item():.4f}")
```
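
The example above scores a single chunk. For a pool of candidates, the same fusion rule can be applied to a whole embedding matrix and followed by a top-k selection. The sketch below illustrates this with hand-made unit vectors instead of model outputs; `rank_chunks` and its `top_k` parameter are hypothetical names introduced for this example, not part of the released code.

```python
import torch
import torch.nn.functional as F


def rank_chunks(q_emb, c_embs, q_res=None, residual_factor=0.5, top_k=3):
    """Rank candidates by the fused similarity used in the example above.

    q_emb, q_res: (1, dim) normalized query embeddings (q_res may be None).
    c_embs: (num_chunks, dim) normalized chunk embeddings.
    Returns (indices, scores) of the top_k chunks, best first.
    """
    scores = q_emb @ c_embs.T  # (1, num_chunks)
    if q_res is not None:
        scores = (1 - residual_factor) * scores + residual_factor * (q_res @ c_embs.T)
    k = min(top_k, c_embs.shape[0])
    top = torch.topk(scores.squeeze(0), k=k)
    return top.indices.tolist(), top.values.tolist()


# Toy embeddings: chunk 2 is similar to both the main and the residual query vector.
c_embs = F.normalize(torch.tensor([[1.0, 0.0, 0.0],
                                   [0.0, 1.0, 0.0],
                                   [1.0, 1.0, 0.0]]), p=2, dim=-1)
q_emb = torch.tensor([[1.0, 0.0, 0.0]])  # stands in for the main query embedding
q_res = torch.tensor([[0.0, 1.0, 0.0]])  # stands in for the residual embedding

indices, scores = rank_chunks(q_emb, c_embs, q_res, residual_factor=0.5, top_k=2)
print(indices, scores)
```

With `residual_factor=0.5`, a chunk that matches only one of the two query vectors scores 0.5, while the chunk that partially matches both scores about 0.71 and is ranked first.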
---
### 3) Node Retrieval
MiA-Emb can retrieve knowledge graph entities (**Nodes**). This mode extracts embeddings from the `<|repo_name|>` token position.
**Candidate format:**
`Entity Name : Entity Description`
Example:
`Mary Campbell Smith : Mary Campbell Smith is mentioned as the translator...`
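
If entities are stored as structured records, a small helper can render them in the candidate format shown above. This is only a convenience sketch: `format_node_candidate` and the `name` / `description` dictionary keys are hypothetical and should be adapted to however your knowledge graph stores its nodes.

```python
def format_node_candidate(name: str, description: str) -> str:
    """Render an entity as 'Entity Name : Entity Description'."""
    return f"{name.strip()} : {description.strip()}"


# Hypothetical entity records; the keys are assumptions for this example.
entities = [
    {
        "name": "Mary Campbell Smith",
        "description": "Mary Campbell Smith is mentioned as the translator...",
    },
    {
        "name": "Harry Potter",
        "description": "The main protagonist of the series...",
    },
]

candidates = [format_node_candidate(e["name"], e["description"]) for e in entities]
for text in candidates:
    print(text)
```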
```python
# Reuses get_query_prompt and encode_chunk from the Chunk Retrieval section.
# extract_specific_token and extract_residual_token must also be defined.

@torch.no_grad()  # Inference only: skip gradient tracking
def encode_node_query(texts, residual=True, node_delimiter="<|repo_name|>"):
    batch = tokenizer(texts, padding=True, return_tensors="pt").to(model.device)
    outputs = model(**batch)

    # 1) Node Main Embedding: extract from <|repo_name|> position
    node_id = tokenizer.encode(node_delimiter, add_special_tokens=False)[0]
    q_emb_node = extract_specific_token(outputs, batch, node_id)

    # 2) Residual Embedding: extract from [PAD] position
    q_emb_res = extract_residual_token(outputs, batch, tokenizer.pad_token_id) if residual else None

    q_emb_node = F.normalize(q_emb_node, p=2, dim=-1)
    q_emb_res = F.normalize(q_emb_res, p=2, dim=-1) if q_emb_res is not None else None
    return q_emb_node, q_emb_res


# --- Example ---
query = "Who is the protagonist?"
global_summ = "A summary of the entire book..."

# 1) Encode Query (Node Token)
q_emb_node, q_emb_res = encode_node_query(
    [get_query_prompt(query, global_summ, residual=True)],
    residual=True
)

# 2) Encode Entity Candidate
entity_text = "Harry Potter : The main protagonist of the series..."
n_emb, _ = encode_chunk([entity_text], is_query=False)

# 3) Score Fusion (same weighting as in Chunk Retrieval)
final_score = q_emb_node @ n_emb.T
if q_emb_res is not None:
    final_score = (1 - residual_factor) * final_score + residual_factor * (q_emb_res @ n_emb.T)
print(f"Node Similarity: {final_score.item():.4f}")
```
---
## 📜 Citation
If you find this work useful, please cite:
```bibtex
@misc{li2025mindscapeawareretrievalaugmentedgeneration,
title={Mindscape-Aware Retrieval Augmented Generation for Improved Long Context Understanding},
author={Yuqing Li and Jiangnan Li and Zheng Lin and Ziyan Zhou and Junjie Wu and Weiping Wang and Jie Zhou and Mo Yu},
year={2025},
eprint={2512.17220},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2512.17220},
}
```