KaiYinTAMU commited on
Commit
83f705c
·
verified ·
1 Parent(s): 8ae9184

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +110 -3
README.md CHANGED
@@ -1,3 +1,110 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - Retrieval
7
+ - LLM
8
+ - Embedding
9
+ library_name: transformers
10
+ ---
11
+
12
+ This model was trained using the approach described in [DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management](https://arxiv.org/abs/2510.15087).
13
+ The associated GitHub repository is available [here](https://github.com/KaiYin97/DMRETRIEVER).
14
+ This model has 335M parameters.
15
+
16
+ ## Usage
17
+
18
+ Using HuggingFace Transformers:
19
+ ```python
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import AutoTokenizer, AutoModel
24
+
25
+ MODEL_NAME = "DMIR01/DMRetriever-335M"
26
+
27
+ # Load model/tokenizer
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ dtype = torch.float16 if device == "cuda" else torch.float32
30
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
31
+ # Some decoder-only models have no pad token; fall back to EOS if needed
32
+ if tokenizer.pad_token is None and tokenizer.eos_token is not None:
33
+ tokenizer.pad_token = tokenizer.eos_token
34
+ model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device)
35
+ model.eval()
36
+
37
+ # Mean pooling over valid tokens (mask==1)
38
+ def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
39
+ mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [B, T, 1]
40
+ summed = (last_hidden_state * mask).sum(dim=1) # [B, H]
41
+ counts = mask.sum(dim=1).clamp(min=1e-9) # [B, 1]
42
+ return summed / counts # [B, H]
43
+
44
+ # Optional task prefixes (use for queries; keep corpus plain)
45
+ TASK2PREFIX = {
46
+ "FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
47
+ "NLI": "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
48
+ "QA": "Given the question, retrieve most relevant passage that best answers the question",
49
+ "QAdoc": "Given the question, retrieve the most relevant document that answers the question",
50
+ "STS": "Given the sentence, retrieve the sentence with the same meaning",
51
+ "Twitter": "Given the user query, retrieve the most relevant Twitter text that meets the request",
52
+ }
53
+ def with_prefix(task: str, text: str) -> str:
54
+ p = TASK2PREFIX.get(task, "")
55
+ return f"{p}: {text}" if p else text
56
+
57
+ # Batch encode with L2 normalization (recommended for cosine/inner-product search)
58
+ @torch.inference_mode()
59
+ def encode_texts(texts, batch_size: int = 32, max_length: int = 512, normalize: bool = True):
60
+ all_embs = []
61
+ for i in range(0, len(texts), batch_size):
62
+ batch = texts[i:i + batch_size]
63
+ toks = tokenizer(
64
+ batch,
65
+ padding=True,
66
+ truncation=True,
67
+ max_length=max_length,
68
+ return_tensors="pt",
69
+ )
70
+ toks = {k: v.to(device) for k, v in toks.items()}
71
+ out = model(**toks, return_dict=True)
72
+ emb = mean_pool(out.last_hidden_state, toks["attention_mask"])
73
+ if normalize:
74
+ emb = F.normalize(emb, p=2, dim=1)
75
+ all_embs.append(emb.cpu().numpy())
76
+ return np.vstack(all_embs) if all_embs else np.empty((0, model.config.hidden_size), dtype=np.float32)
77
+
78
+ # ---- Example: plain sentences ----
79
+ sentences = [
80
+ "A cat sits on the mat.",
81
+ "The feline is resting on the rug.",
82
+ "Quantum mechanics studies matter and light.",
83
+ ]
84
+ embs = encode_texts(sentences) # shape: [N, hidden_size]
85
+ print("Embeddings shape:", embs.shape)
86
+
87
+ # Cosine similarity (embeddings are L2-normalized)
88
+ sims = embs @ embs.T
89
+ print("Cosine similarity matrix:\n", np.round(sims, 3))
90
+
91
+ # ---- Example: query with task prefix (QA) ----
92
+ qa_queries = [
93
+ with_prefix("QA", "Who wrote 'Pride and Prejudice'?"),
94
+ with_prefix("QA", "What is the capital of Japan?"),
95
+ ]
96
+ qa_embs = encode_texts(qa_queries)
97
+ print("QA Embeddings shape:", qa_embs.shape)
98
+
99
+ ```
100
+ ## Citation
101
+ If you find this repository helpful, please consider citing the corresponding paper. Thanks!
102
+ ```
103
+ @article{yin2025dmretriever,
104
+ title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
105
+ author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
106
+ journal={arXiv preprint arXiv:2510.15087},
107
+ year={2025}
108
+ }
109
+ ```
110
+