Text Ranking
Transformers
Safetensors
multilingual
t5gemma2
text2text-generation
reranker
encoder-decoder
FBNL
Retrieval
RAG
cosyy committed on
Commit
c73f423
·
verified ·
1 Parent(s): 1dcb875

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +88 -10
README.md CHANGED
@@ -91,20 +91,94 @@ On LMEB, reranking models demonstrate a clear advantage, with even the 0.27B Nan
91
  # Usage
92
  ```python
93
  import argparse
94
-
95
- from kalm_reranker import KaLMReranker
96
-
97
-
98
- def main() -> None:
99
- parser = argparse.ArgumentParser()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  parser.add_argument(
101
  "--model",
102
- default="KaLM-Embedding/KaLM-Reranker-V1-Small"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
- parser.add_argument("--device", default=None)
105
- args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- reranker = KaLMReranker(args.model, device=args.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  query = "What is the capital of China?"
109
  documents = [
110
  "The capital of China is Beijing.",
@@ -115,11 +189,15 @@ def main() -> None:
115
  pairs = [(query, document) for document in documents]
116
  print("scores:", reranker.predict(pairs, instruction=instruction))
117
  print("rankings:", reranker.rank(query, documents, instruction=instruction))
 
 
 
118
 
119
 
120
  if __name__ == "__main__":
121
  main()
122
 
 
123
  ```
124
 
125
  # Citation
 
91
  # Usage
92
  ```python
93
  import argparse
94
+ from typing import Optional
95
+
96
+
97
def optional_positive_int(value: str) -> Optional[int]:
    """Parse a command-line value that is a positive integer or 'none'.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        value: Raw argument text. The 'none' keyword is matched
            case-insensitively, and surrounding whitespace is ignored for
            both accepted forms.

    Returns:
        ``None`` when the text is 'none', otherwise the parsed integer.

    Raises:
        argparse.ArgumentTypeError: If the text is neither 'none' nor an
            integer greater than zero.
    """
    message = "must be a positive integer or 'none'"
    # int() already tolerates surrounding whitespace, so strip once up front
    # to give the 'none' keyword the same leniency.
    text = value.strip()
    if text.lower() == "none":
        return None
    try:
        parsed = int(text)
    except ValueError as error:
        raise argparse.ArgumentTypeError(message) from error
    if parsed <= 0:
        raise argparse.ArgumentTypeError(message)
    return parsed
109
+
110
+
111
def _positive_int(value: str) -> int:
    """Parse *value* as an integer greater than zero for ``argparse``."""
    message = "must be a positive integer"
    try:
        parsed = int(value)
    except ValueError as error:
        raise argparse.ArgumentTypeError(message) from error
    if parsed <= 0:
        raise argparse.ArgumentTypeError(message)
    return parsed


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the reranker usage example.

    Defines ``--model``, ``--device``, ``--dtype``, ``--batch-size``,
    ``--query-max-length``, ``--reranker-max-length`` and ``--chunk-size``.
    Defaults are shown in ``--help`` output through
    ``ArgumentDefaultsHelpFormatter``.

    Returns:
        The configured parser; call ``parse_args()`` on it to read the options.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        default="KaLM-Embedding/KaLM-Reranker-V1-Small",
        help="Hugging Face model ID or local checkpoint path.",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Inference device, such as 'cuda', 'cuda:0', or 'cpu'.",
    )
    parser.add_argument(
        "--dtype",
        default=None,
        choices=("bfloat16", "bf16", "float16", "fp16", "float32", "fp32"),
        help="Model parameter dtype. By default, use BF16 on CUDA and FP32 on CPU.",
    )
    # The size and length options reject zero and negative values at parse
    # time, matching the positivity check already applied to --chunk-size.
    parser.add_argument(
        "--batch-size",
        type=_positive_int,
        default=32,
        help="Number of query-document pairs scored per inference batch.",
    )
    parser.add_argument(
        "--query-max-length",
        type=_positive_int,
        default=512,
        help=(
            "Maximum tokens in the raw query before it is inserted into the "
            "decoder prompt; prompt tokens are not included in this limit."
        ),
    )
    parser.add_argument(
        "--reranker-max-length",
        type=_positive_int,
        default=1024,
        help=(
            "Maximum encoder tokens for '<Document>: {passage}'. This is not a "
            "combined query-document context limit."
        ),
    )
    parser.add_argument(
        "--chunk-size",
        type=optional_positive_int,
        default=4,
        metavar="N|none",
        help=(
            "Number of encoder token hidden states per mean-pooled chunk; use "
            "'none' to disable encoder chunk pooling."
        ),
    )
    return parser
166
+
167
 
168
+ def main() -> None:
169
+ args = build_parser().parse_args()
170
+
171
+ from kalm_reranker import KaLMReranker
172
+
173
+ reranker = KaLMReranker(
174
+ args.model,
175
+ device=args.device,
176
+ dtype=args.dtype,
177
+ batch_size=args.batch_size,
178
+ query_max_length=args.query_max_length,
179
+ max_length=args.reranker_max_length,
180
+ chunk_size=args.chunk_size,
181
+ )
182
  query = "What is the capital of China?"
183
  documents = [
184
  "The capital of China is Beijing.",
 
189
  pairs = [(query, document) for document in documents]
190
  print("scores:", reranker.predict(pairs, instruction=instruction))
191
  print("rankings:", reranker.rank(query, documents, instruction=instruction))
192
+
193
+ # scores: [0.9999822378158569, 3.187565198459197e-06]
194
+ # rankings: [{'corpus_id': 0, 'score': 0.9999822378158569}, {'corpus_id': 1, 'score': 3.187565198459197e-06}]
195
 
196
 
197
  if __name__ == "__main__":
198
  main()
199
 
200
+
201
  ```
202
 
203
  # Citation