---
license: apache-2.0
language:
- en
tags:
- Retrieval
- LLM
- Embedding
---

This model is trained through the approach described in [DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management](https://www.arxiv.org/abs/2510.15087).
The associated GitHub repository is available [here](https://github.com/KaiYin97/DMRETRIEVER).
This model has 4B parameters.

## 🧠 Model Overview

**DMRetriever-4B** has the following features:

- Model Type: Text Embedding
- Supported Languages: English
- Number of Parameters: 4B
- Context Length: 512
- Embedding Dimension: 2560

For more details, including model training, benchmark evaluation, and inference performance, please refer to our [paper](https://www.arxiv.org/abs/2510.15087) and [GitHub repository](https://github.com/KaiYin97/DMRETRIEVER).

## 📦 DMRetriever Series Model List

| **Model** | **Description** | **Backbone** | **Backbone Type** | **Hidden Size** | **#Layers** |
|:--|:--|:--|:--|:--:|:--:|
| [DMRetriever-33M](https://huggingface.co/DMIR01/DMRetriever-33M) | Base 33M variant | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-33M-PT](https://huggingface.co/DMIR01/DMRetriever-33M-PT) | Pre-trained version of 33M | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-109M](https://huggingface.co/DMIR01/DMRetriever-109M) | Base 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-109M-PT](https://huggingface.co/DMIR01/DMRetriever-109M-PT) | Pre-trained version of 109M | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-335M](https://huggingface.co/DMIR01/DMRetriever-335M) | Base 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-335M-PT](https://huggingface.co/DMIR01/DMRetriever-335M-PT) | Pre-trained version of 335M | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-596M](https://huggingface.co/DMIR01/DMRetriever-596M) | Base 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-596M-PT](https://huggingface.co/DMIR01/DMRetriever-596M-PT) | Pre-trained version of 596M | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-4B](https://huggingface.co/DMIR01/DMRetriever-4B) | Base 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-4B-PT](https://huggingface.co/DMIR01/DMRetriever-4B-PT) | Pre-trained version of 4B | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-7.6B](https://huggingface.co/DMIR01/DMRetriever-7.6B) | Base 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |
| [DMRetriever-7.6B-PT](https://huggingface.co/DMIR01/DMRetriever-7.6B-PT) | Pre-trained version of 7.6B | Qwen3-8B | Decoder-only | 4096 | 36 |

## 🚀 Usage

Using HuggingFace Transformers:

```python
# pip install torch transformers
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from bidirectional_qwen3 import Qwen3BiModel  # custom bidirectional backbone

MODEL_ID = "DMIR01/DMRetriever-4B"

# Device & dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Tokenizer (needs remote code for custom modules) ---
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
)
# Ensure pad token and right padding (matches training)
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Bidirectional encoder (non-autoregressive; for retrieval/embedding) ---
model = Qwen3BiModel.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

# --- Mean pooling over valid tokens ---
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [B, H]
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # [B, 1]
    return summed / counts

# --- Batch encoder: returns L2-normalized embeddings ---
def encode_texts(texts, batch_size=32, max_length=512):
    vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        with torch.no_grad():
            inputs = tokenizer(
                batch,
                max_length=max_length,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(device)
            hidden = model(**inputs).last_hidden_state
            emb = mean_pool(hidden, inputs["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)  # cosine-ready
            vecs.append(emb.cpu())
    return torch.cat(vecs, dim=0) if vecs else torch.empty(0, model.config.hidden_size)

# --- Task instructions (apply to queries only) ---
TASK2PREFIX = {
    "FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
    "NLI": "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
    "QA": "Given the question, retrieve most relevant passage that best answers the question",
    "QAdoc": "Given the question, retrieve the most relevant document that answers the question",
    "STS": "Given the sentence, retrieve the sentence with the same meaning",
    "Twitter": "Given the user query, retrieve the most relevant Twitter text that meets the request",
}

def apply_task_prefix(queries, task: str):
    """Add instruction to queries; corpus texts remain unchanged."""
    prefix = TASK2PREFIX.get(task, "")
    if prefix:
        return [f"{prefix}: {q.strip()}" for q in queries]
    return [q.strip() for q in queries]

# ========================= Usage =========================
# Queries need task instruction
task = "QA"
queries_raw = [
    "Who wrote The Little Prince?",
    "What is the capital of France?",
]
queries = apply_task_prefix(queries_raw, task)

# Corpus: no instruction
corpus_passages = [
    "The Little Prince is a novella by Antoine de Saint-Exupéry, first published in 1943.",
    "Paris is the capital and most populous city of France.",
    "Transformers are neural architectures that rely on attention mechanisms.",
]

# Encode
query_emb = encode_texts(queries, batch_size=32, max_length=512)           # [Q, H]
corpus_emb = encode_texts(corpus_passages, batch_size=32, max_length=512)  # [D, H]
print("Query embeddings:", tuple(query_emb.shape))
print("Corpus embeddings:", tuple(corpus_emb.shape))

# Retrieval demo: cosine similarity via dot product (embeddings are normalized)
scores = query_emb @ corpus_emb.T  # [Q, D]
topk = scores.topk(k=min(3, corpus_emb.size(0)), dim=1)

for i, q in enumerate(queries_raw):
    print(f"\nQuery[{i}] {q}")
    for rank, (score, idx) in enumerate(zip(topk.values[i].tolist(), topk.indices[i].tolist()), start=1):
        print(f"  Top{rank}: doc#{idx} | score={score:.4f} | text={corpus_passages[idx]}")
```
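
Because `encode_texts` L2-normalizes its outputs, the dot product in the retrieval demo is exactly cosine similarity. A minimal pure-Python sanity check of that identity, using toy vectors (no model or GPU required; `l2_normalize`, `dot`, and `cosine` here are illustrative helpers, not part of this repository):

```python
import math

def l2_normalize(v):
    # Scale a vector to unit length
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine(a, b):
    # Cosine similarity of the raw (unnormalized) vectors
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

a, b = [3.0, 4.0], [4.0, 3.0]
# Dot product of normalized vectors equals cosine of the originals
print(round(dot(l2_normalize(a), l2_normalize(b)), 4))  # 0.96
print(round(cosine(a, b), 4))                           # 0.96
```

This is why the demo can rank documents with a plain matrix product instead of an explicit cosine computation.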

## ⚠️ Notice

1. The **backbone** used in DMRetriever is **Bidirectional Qwen3**, not the standard Qwen3.
   Please ensure that the `bidirectional_qwen3` module (included in the released model checkpoint folder) is correctly placed inside your model directory.

2. Make sure that your **transformers** library version is **≥ 4.51.0** to avoid the error `KeyError: 'qwen3'`.

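The version requirement above can be checked up front before loading the model. A rough sketch of such a check (a hypothetical helper; it assumes plain `major.minor.patch` version strings, so production code should prefer `packaging.version` for pre-release suffixes like `4.51.0.dev0`):

```python
MIN_TRANSFORMERS = "4.51.0"  # minimum version that registers the qwen3 architecture

def version_tuple(v: str):
    # Naive parse: assumes purely numeric dot-separated components
    return tuple(int(p) for p in v.split("."))

def supports_qwen3(installed: str) -> bool:
    # Tuple comparison orders versions component by component
    return version_tuple(installed) >= version_tuple(MIN_TRANSFORMERS)

print(supports_qwen3("4.50.3"))  # False -> upgrade before loading DMRetriever
print(supports_qwen3("4.51.0"))  # True
```

In practice you would pass `transformers.__version__` to `supports_qwen3` and upgrade (`pip install -U transformers`) when it returns `False`.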

## 🧾 Citation

If you find this repository helpful, please consider citing the corresponding paper. Thanks!

```
@article{yin2025dmretriever,
  title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
  author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
  journal={arXiv preprint arXiv:2510.15087},
  year={2025}
}
```