---
license: apache-2.0
language:
- en
tags:
- information-retrieval
- LLM
- Embedding
- text-retrieval
- disaster-management
task_categories:
- text-retrieval
library_name: transformers
dataset_tags:
- DMIR01/DMRetriever_MTT
---

This model was trained with the approach described in [DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management](https://www.arxiv.org/abs/2510.15087).

The associated GitHub repository is available [here](https://github.com/KaiYin97/DMRETRIEVER).

This model has 4B parameters.

## 🧠 Model Overview

**DMRetriever-4B** has the following features:

- Model Type: Text Embedding
- Supported Languages: English
- Number of Parameters: 4B
- Embedding Dimension: 2560

For more details, including model training, benchmark evaluation, and inference performance, please refer to our [paper](https://www.arxiv.org/abs/2510.15087) and [GitHub repository](https://github.com/KaiYin97/DMRETRIEVER).

## 📦 DMRetriever Series Model List

| **Model** | **Description** | **Backbone** | **Backbone Type** | **Hidden Size** | **#Layers** |
|:--|:--|:--|:--|:--:|:--:|
| [DMRetriever-33M](https://huggingface.co/DMIR01/DMRetriever-33M) | Base 33M variant | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-33M-PT](https://huggingface.co/DMIR01/DMRetriever-33M-PT) | Pre-trained version of the 33M variant | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-109M](https://huggingface.co/DMIR01/DMRetriever-109M) | Base 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-109M-PT](https://huggingface.co/DMIR01/DMRetriever-109M-PT) | Pre-trained version of the 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-335M](https://huggingface.co/DMIR01/DMRetriever-335M) | Base 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-335M-PT](https://huggingface.co/DMIR01/DMRetriever-335M-PT) | Pre-trained version of the 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-596M](https://huggingface.co/DMIR01/DMRetriever-596M) | Base 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-596M-PT](https://huggingface.co/DMIR01/DMRetriever-596M-PT) | Pre-trained version of the 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-4B](https://huggingface.co/DMIR01/DMRetriever-4B) | Base 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-4B-PT](https://huggingface.co/DMIR01/DMRetriever-4B-PT) | Pre-trained version of the 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-7.6B](https://huggingface.co/DMIR01/DMRetriever-7.6B) | Base 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |
| [DMRetriever-7.6B-PT](https://huggingface.co/DMIR01/DMRetriever-7.6B-PT) | Pre-trained version of the 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |

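The embedding dimension determines how much memory a vector index needs. The sketch below is a back-of-the-envelope estimate under two assumptions: each variant's embedding dimension equals the "Hidden Size" column (this holds for the 4B model, whose 2560-dim embeddings come from mean-pooling the last hidden state), and index-structure overhead is ignored. `index_size_bytes` is a hypothetical helper, not part of the DMRetriever release.

```python
# Hypothetical helper: raw storage needed for a flat embedding index.
# Assumption: embedding dimension == "Hidden Size" from the table above,
# because the usage example mean-pools the backbone's last hidden state.
EMBED_DIM = {
    "DMRetriever-33M": 384,
    "DMRetriever-109M": 768,
    "DMRetriever-335M": 1024,
    "DMRetriever-596M": 1024,
    "DMRetriever-4B": 2560,
    "DMRetriever-7.6B": 4096,
}

BYTES_PER_VALUE = {"float32": 4, "float16": 2}


def index_size_bytes(model: str, num_docs: int, dtype: str = "float32") -> int:
    """Return num_docs * dim * bytes-per-value (no index overhead)."""
    # -PT checkpoints share the backbone (and hidden size) of their base variant.
    base = model.removesuffix("-PT")
    return num_docs * EMBED_DIM[base] * BYTES_PER_VALUE[dtype]


for name in ("DMRetriever-33M", "DMRetriever-4B", "DMRetriever-7.6B"):
    gb = index_size_bytes(name, 1_000_000) / 1e9
    print(f"{name}: {gb:.2f} GB for 1M float32 embeddings")
```

For example, one million float32 embeddings from DMRetriever-4B take 1,000,000 × 2560 × 4 bytes, about 10.24 GB; storing them as float16 halves that.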
## 🚀 Usage

Using Hugging Face Transformers:

```python
# pip install torch transformers
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from bidirectional_qwen3 import Qwen3BiModel  # custom bidirectional backbone

MODEL_ID = "DMIR01/DMRetriever-4B"

# Device & dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Tokenizer (needs remote code for custom modules) ---
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
)

# Ensure pad token and right padding (matches training)
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Bidirectional encoder (non-autoregressive; for retrieval/embedding) ---
model = Qwen3BiModel.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()


# --- Mean pooling over valid (non-padding) tokens ---
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)  # [B, H]
    # Token counts are whole numbers, so min=1 only guards against 0/0
    # (a tiny epsilon such as 1e-9 underflows to 0 in float16).
    counts = mask.sum(dim=1).clamp(min=1.0)  # [B, 1]
    return summed / counts


# --- Batch encoder: returns L2-normalized embeddings ---
def encode_texts(texts, batch_size=32, max_length=512):
    vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        with torch.no_grad():
            inputs = tokenizer(
                batch,
                max_length=max_length,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(device)
            hidden = model(**inputs).last_hidden_state
            emb = mean_pool(hidden, inputs["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)  # cosine-ready
        # Cast to float32 so the CPU matmul below also works with fp16 GPU outputs
        vecs.append(emb.float().cpu())
    return torch.cat(vecs, dim=0) if vecs else torch.empty(0, model.config.hidden_size)


# --- Task instructions (apply to queries only) ---
TASK2PREFIX = {
    "FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
    "NLI": "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
    "QA": "Given the question, retrieve most relevant passage that best answers the question",
    "QAdoc": "Given the question, retrieve the most relevant document that answers the question",
    "STS": "Given the sentence, retrieve the sentence with the same meaning",
    "Twitter": "Given the user query, retrieve the most relevant Twitter text that meets the request",
}


def apply_task_prefix(queries, task: str):
    """Add instruction to queries; corpus texts remain unchanged."""
    prefix = TASK2PREFIX.get(task, "")
    if prefix:
        return [f"{prefix}: {q.strip()}" for q in queries]
    return [q.strip() for q in queries]


# ========================= Usage =========================
# Queries need a task instruction
task = "QA"
queries_raw = [
    "Who wrote The Little Prince?",
    "What is the capital of France?",
]
queries = apply_task_prefix(queries_raw, task)

# Corpus: no instruction
corpus_passages = [
    "The Little Prince is a novella by Antoine de Saint-Exupéry, first published in 1943.",
    "Paris is the capital and most populous city of France.",
    "Transformers are neural architectures that rely on attention mechanisms.",
]

# Encode
query_emb = encode_texts(queries, batch_size=32, max_length=512)  # [Q, H]
corpus_emb = encode_texts(corpus_passages, batch_size=32, max_length=512)  # [D, H]
print("Query embeddings:", tuple(query_emb.shape))
print("Corpus embeddings:", tuple(corpus_emb.shape))

# Retrieval demo: cosine similarity via dot product (embeddings are normalized)
scores = query_emb @ corpus_emb.T  # [Q, D]
topk = scores.topk(k=min(3, corpus_emb.size(0)), dim=1)

for i, q in enumerate(queries_raw):
    print(f"\nQuery[{i}] {q}")
    for rank, (score, idx) in enumerate(zip(topk.values[i].tolist(), topk.indices[i].tolist()), start=1):
        print(f"  Top{rank}: doc#{idx} | score={score:.4f} | text={corpus_passages[idx]}")
```

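To see why right padding does not change an embedding, here is a dependency-free restatement of the `mean_pool` and `F.normalize` steps on toy numbers. It is an illustration only: plain Python lists stand in for tensors, and although the helper names mirror the example, this is not the released implementation.

```python
import math


def mean_pool(token_vectors, attention_mask):
    """Average the vectors whose mask entry is 1; padded positions are ignored."""
    dim = len(token_vectors[0])
    summed = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, attention_mask):
        if keep:
            summed = [s + v for s, v in zip(summed, vec)]
            count += 1
    count = max(count, 1)  # same role as the clamp: avoid dividing by zero
    return [s / count for s in summed]


def l2_normalize(vec, eps=1e-12):
    """Scale to unit length, as F.normalize(p=2) does."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


# Two real tokens plus one padding token (right padding, mask = 0).
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, -100.0]]
mask = [1, 1, 0]

pooled = mean_pool(tokens, mask)
embedding = l2_normalize(pooled)
print(pooled, embedding)
```

The padding vector's large values never enter the average, so the pooled vector is identical to pooling the two real tokens alone. After normalization, a plain dot product between two embeddings equals their cosine similarity, which is why the retrieval demo can score with `query_emb @ corpus_emb.T`.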
## ⚠️ Notice

1. The **backbone** used in DMRetriever is **Bidirectional Qwen3**, not the standard Qwen3.
   Please ensure that the `bidirectional_qwen3` module (included in the released model checkpoint folder) is placed inside your model directory and is importable from your Python path, since the usage example imports it directly.

2. Make sure that your **transformers** library version is **4.51.0 or later** to avoid the error `KeyError: 'qwen3'`.

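Both requirements can be checked before loading the model. The sketch below is a hypothetical preflight helper, not part of the release. It assumes `bidirectional_qwen3` ships as a `.py` file or package at the top level of the checkpoint folder, and it compares version strings with a simple numeric parser instead of the `packaging` library.

```python
import importlib
import os
import re
import sys

MIN_TRANSFORMERS = (4, 51, 0)  # Qwen3 support is available from transformers 4.51.0


def parse_version(text: str) -> tuple:
    """Turn '4.51.3' or '4.52.0.dev0' into (4, 51, 3) / (4, 52, 0)."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts + [0] * (3 - len(parts)))


def transformers_is_recent_enough(installed: str) -> bool:
    return parse_version(installed) >= MIN_TRANSFORMERS


def add_checkpoint_to_path(checkpoint_dir: str, module: str = "bidirectional_qwen3") -> bool:
    """Put checkpoint_dir on sys.path if it contains `module` as a .py file or package."""
    as_file = os.path.isfile(os.path.join(checkpoint_dir, module + ".py"))
    as_pkg = os.path.isfile(os.path.join(checkpoint_dir, module, "__init__.py"))
    if not (as_file or as_pkg):
        return False
    if checkpoint_dir not in sys.path:
        sys.path.insert(0, checkpoint_dir)
    importlib.invalidate_caches()  # make the newly added path visible to imports
    return True


# Usage sketch (assumes network access and the `huggingface_hub` package;
# snapshot_download fetches the full checkpoint, including the weights):
#   import transformers
#   from huggingface_hub import snapshot_download
#   assert transformers_is_recent_enough(transformers.__version__)
#   local_dir = snapshot_download(repo_id="DMIR01/DMRetriever-4B")
#   assert add_checkpoint_to_path(local_dir)
#   from bidirectional_qwen3 import Qwen3BiModel
```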
## 🧾 Citation

If you find this repository helpful, please consider citing the corresponding paper. Thanks!

```bibtex
@article{yin2025dmretriever,
  title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
  author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
  journal={arXiv preprint arXiv:2510.15087},
  year={2025}
}
```