---
license: apache-2.0
language:
- en
tags:
- information-retrieval
- LLM
- Embedding
- text-retrieval
- disaster-management
task_categories:
- text-retrieval
library_name: transformers
dataset_tags:
- DMIR01/DMRetriever_MTT
---
This model was trained with the approach described in [DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management](https://www.arxiv.org/abs/2510.15087).
The associated GitHub repository is available [here](https://github.com/KaiYin97/DMRETRIEVER).
The model has 33M parameters.
## 🧠 Model Overview
**DMRetriever-33M** has the following features:
- Model Type: Text Embedding
- Supported Languages: English
- Number of Parameters: 33M
- Embedding Dimension: 384
For more details, including model training, benchmark evaluation, and inference performance, please refer to our [paper](https://www.arxiv.org/abs/2510.15087) and [GitHub repository](https://github.com/KaiYin97/DMRETRIEVER).
## 📦 DMRetriever Series Model List
| **Model** | **Description** | **Backbone** | **Backbone Type** | **Hidden Size** | **#Layers** |
|:--|:--|:--|:--|:--:|:--:|
| [DMRetriever-33M](https://huggingface.co/DMIR01/DMRetriever-33M) | Base 33M variant | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-33M-PT](https://huggingface.co/DMIR01/DMRetriever-33M-PT) | Pre-trained version of 33M | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-109M](https://huggingface.co/DMIR01/DMRetriever-109M) | Base 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-109M-PT](https://huggingface.co/DMIR01/DMRetriever-109M-PT) | Pre-trained version of 109M | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-335M](https://huggingface.co/DMIR01/DMRetriever-335M) | Base 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-335M-PT](https://huggingface.co/DMIR01/DMRetriever-335M-PT) | Pre-trained version of 335M | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-596M](https://huggingface.co/DMIR01/DMRetriever-596M) | Base 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-596M-PT](https://huggingface.co/DMIR01/DMRetriever-596M-PT) | Pre-trained version of 596M | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-4B](https://huggingface.co/DMIR01/DMRetriever-4B) | Base 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-4B-PT](https://huggingface.co/DMIR01/DMRetriever-4B-PT) | Pre-trained version of 4B | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-7.6B](https://huggingface.co/DMIR01/DMRetriever-7.6B) | Base 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |
| [DMRetriever-7.6B-PT](https://huggingface.co/DMIR01/DMRetriever-7.6B-PT) | Pre-trained version of 7.6B | Qwen3-8B | Decoder-only | 4096 | 36 |
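Across all variants, the embedding dimension equals the backbone's hidden size because a sentence vector is obtained by mean-pooling token states over the attention mask (as in the usage snippet below). A minimal numpy sketch of masked mean pooling with toy values (batch of 2, sequence length 3, hidden size 4; the second sequence has one padding token):

```python
import numpy as np

# Toy "last hidden states": batch=2, seq_len=3, hidden=4
hidden = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
# Attention mask: 1 = real token, 0 = padding
mask = np.array([[1, 1, 1],
                 [1, 1, 0]], dtype=np.float64)

# Masked mean pooling: sum valid token states, divide by valid-token count
summed = (hidden * mask[:, :, None]).sum(axis=1)   # [2, 4]
counts = mask.sum(axis=1, keepdims=True)           # [2, 1]
pooled = summed / counts                           # [2, 4]
print(pooled)
```

Padding positions contribute nothing to the sum and are excluded from the count, so padded and unpadded batches yield identical embeddings for the same text.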
## 🚀 Usage
Using HuggingFace Transformers:
```python
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
MODEL_NAME = "DMIR01/DMRetriever-33M"
# Load model/tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Some decoder-only models have no pad token; fall back to EOS if needed
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device)
model.eval()
# Mean pooling over valid tokens (mask==1)
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [B, T, 1]
summed = (last_hidden_state * mask).sum(dim=1) # [B, H]
counts = mask.sum(dim=1).clamp(min=1e-9) # [B, 1]
return summed / counts # [B, H]
# Optional task prefixes (use for queries; keep corpus plain)
TASK2PREFIX = {
"FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
"NLI": "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
"QA": "Given the question, retrieve most relevant passage that best answers the question",
"QAdoc": "Given the question, retrieve the most relevant document that answers the question",
"STS": "Given the sentence, retrieve the sentence with the same meaning",
"Twitter": "Given the user query, retrieve the most relevant Twitter text that meets the request",
}
def with_prefix(task: str, text: str) -> str:
p = TASK2PREFIX.get(task, "")
return f"{p}: {text}" if p else text
# Batch encode with L2 normalization (recommended for cosine/inner-product search)
@torch.inference_mode()
def encode_texts(texts, batch_size: int = 32, max_length: int = 512, normalize: bool = True):
all_embs = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
toks = tokenizer(
batch,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
toks = {k: v.to(device) for k, v in toks.items()}
out = model(**toks, return_dict=True)
emb = mean_pool(out.last_hidden_state, toks["attention_mask"])
if normalize:
emb = F.normalize(emb, p=2, dim=1)
all_embs.append(emb.cpu().numpy())
return np.vstack(all_embs) if all_embs else np.empty((0, model.config.hidden_size), dtype=np.float32)
# ---- Example: plain sentences ----
sentences = [
"A cat sits on the mat.",
"The feline is resting on the rug.",
"Quantum mechanics studies matter and light.",
]
embs = encode_texts(sentences) # shape: [N, hidden_size]
print("Embeddings shape:", embs.shape)
# Cosine similarity (embeddings are L2-normalized)
sims = embs @ embs.T
print("Cosine similarity matrix:\n", np.round(sims, 3))
# ---- Example: query with task prefix (QA) ----
qa_queries = [
with_prefix("QA", "Who wrote 'Pride and Prejudice'?"),
with_prefix("QA", "What is the capital of Japan?"),
]
qa_embs = encode_texts(qa_queries)
print("QA Embeddings shape:", qa_embs.shape)
```
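Because `encode_texts` L2-normalizes its output, dense retrieval reduces to a dot-product top-k search over the corpus matrix. A minimal sketch with placeholder vectors standing in for `encode_texts` output (in practice, encode prefixed queries and plain corpus texts with the model above; the synthetic queries here are noisy copies of corpus rows 1 and 3):

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Scale each row to unit length so dot product == cosine similarity
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def top_k(query_embs: np.ndarray, corpus_embs: np.ndarray, k: int = 2):
    sims = query_embs @ corpus_embs.T                # [Q, N] cosine scores
    idx = np.argsort(-sims, axis=1)[:, :k]           # top-k corpus indices per query
    scores = np.take_along_axis(sims, idx, axis=1)   # matching scores
    return idx, scores

# Placeholder embeddings (would come from encode_texts in practice)
rng = np.random.default_rng(0)
corpus = l2_normalize(rng.normal(size=(5, 384)))
queries = l2_normalize(corpus[[1, 3]] + 0.01 * rng.normal(size=(2, 384)))

idx, scores = top_k(queries, corpus, k=2)
print("Nearest corpus indices per query:", idx[:, 0])
```

For large corpora, the same normalized embeddings can be indexed with an inner-product ANN library instead of the exhaustive matrix product shown here.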
## 🧾 Citation
If you find this repository helpful, please consider citing the corresponding paper. Thanks!
```bibtex
@article{yin2025dmretriever,
title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
journal={arXiv preprint arXiv:2510.15087},
year={2025}
}
```