cstr committed on
Commit ba9c38b · verified · 1 Parent(s): 73e1c32

Upload README.md with huggingface_hub

Files changed (1):
  1. README.md +38 -17

README.md CHANGED
@@ -25,18 +25,31 @@ ONNX export of [zeroentropy/zerank-1-small](https://huggingface.co/zeroentropy/z
  | `model_int8.onnx` + `model_int8.onnx_data` | INT8 | ~2.5 GB | Weight-only INT8 (per-tensor symmetric) |
  | `model_int4_full.onnx` | INT4 | ~1.3 GB | MatMulNBits INT4, block_size=32 |

- Conversion scripts: `export_zerank.py` (FP16 export), `stream_int8.py` (INT8 quantization).

  ## ⚠️ Important: chat template required

  This model is a Qwen3-based causal LM that scores (query, document) relevance by extracting the **"Yes" token logit** at the last position. It requires a specific prompt format — plain pair tokenization produces meaningless scores.

- **Always format inputs as:**
  ```
  <|im_start|>user
- Query: {query}
- Document: {document}
- Relevant:<|im_end|>
  <|im_start|>assistant
  ```

@@ -48,16 +61,23 @@ import numpy as np
  from transformers import AutoTokenizer

  MODEL_PATH = "model_int8.onnx"  # or model.onnx, model_int4_full.onnx
- TEMPLATE = "<|im_start|>user\nQuery: {query}\nDocument: {doc}\nRelevant:<|im_end|>\n<|im_start|>assistant\n"

  sess = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
  tok = AutoTokenizer.from_pretrained("cstr/zerank-1-small-ONNX")

  def rerank(query: str, documents: list[str]) -> list[float]:
      scores = []
      for doc in documents:
-         text = TEMPLATE.format(query=query, doc=doc)
-         enc = tok(text, return_tensors="np", truncation=True, max_length=512)
          logit = sess.run(["logits"], {
              "input_ids": enc["input_ids"].astype(np.int64),
              "attention_mask": enc["attention_mask"].astype(np.int64),
@@ -74,12 +94,12 @@ docs = [
  scores = rerank(query, docs)
  for s, d in sorted(zip(scores, docs), reverse=True):
      print(f"[{s:.3f}] {d}")
- # [6.8] The giant panda is a bear species endemic to China.
- # [2.1] Pandas are mammals in the family Ursidae.
  # [-5.8] The sky is blue and the grass is green.
  ```

- > **Note:** Current export uses `batch_size=1` (causal mask is static). Process documents one at a time as shown above.

  ## Usage with fastembed-rs
 
@@ -90,7 +110,7 @@ let mut reranker = TextRerank::try_new(
      RerankInitOptions::new(RerankerModel::ZerankSmallInt8)
  ).unwrap();

- // batch_size=1: chat template is applied automatically per document
  let results = reranker.rerank(
      "What is a panda?",
      vec![
@@ -99,7 +119,7 @@ let results = reranker.rerank(
          "Pandas are mammals in the family Ursidae.",
      ],
      true,
-     Some(1),
  ).unwrap();

  for r in &results {
@@ -109,11 +129,12 @@ for r in &results {

  ## Export details

- `export_zerank.py` wraps Qwen3ForCausalLM in a `ZeRankScorer` that:

- 1. Runs the transformer body → `hidden [batch, seq, hidden]`
- 2. Gathers the hidden state at the last real-token position (`attention_mask.sum - 1`)
- 3. Applies `lm_head`, slices the **"Yes" token** (id `9454`) → `[batch, 1]`

  Output: `logits [batch, 1]` — raw Yes-token logit (higher = more relevant). FP16 weights, opset 18.
 
 
  | `model_int8.onnx` + `model_int8.onnx_data` | INT8 | ~2.5 GB | Weight-only INT8 (per-tensor symmetric) |
  | `model_int4_full.onnx` | INT4 | ~1.3 GB | MatMulNBits INT4, block_size=32 |

+ Conversion scripts: `export_zerank_v2.py` (FP16 export with dynamic batch), `stream_int8.py` (INT8 quantization).

  ## ⚠️ Important: chat template required

  This model is a Qwen3-based causal LM that scores (query, document) relevance by extracting the **"Yes" token logit** at the last position. It requires a specific prompt format — plain pair tokenization produces meaningless scores.

+ **Always format inputs using the Qwen3 chat template with `system=query`, `user=document`:**
+
+ ```python
+ # using the tokenizer directly (matches training format exactly):
+ messages = [
+     {"role": "system", "content": query},
+     {"role": "user", "content": document},
+ ]
+ text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  ```
+
+ This produces the following fixed string (equivalent, usable without a tokenizer):
+ ```
+ <|im_start|>system
+ {query}
+ <|im_end|>
  <|im_start|>user
+ {document}
+ <|im_end|>
  <|im_start|>assistant
  ```

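Where a tokenizer is unavailable, the fixed string can be assembled by hand. A minimal sketch, assuming the line layout shown in the template block (the helper name `format_pair_plain` is hypothetical; verify the exact whitespace against `tok.apply_chat_template` before relying on it):

```python
def format_pair_plain(query: str, document: str) -> str:
    # Hypothetical helper: mirrors the fixed template above, line by line.
    # Check the output against tok.apply_chat_template before relying on it.
    return (
        "<|im_start|>system\n"
        f"{query}\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        f"{document}\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

text = format_pair_plain("What is a panda?",
                         "The giant panda is a bear species endemic to China.")
print(text)
```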
 
 
  from transformers import AutoTokenizer

  MODEL_PATH = "model_int8.onnx"  # or model.onnx, model_int4_full.onnx
+ MAX_LENGTH = 512

  sess = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
  tok = AutoTokenizer.from_pretrained("cstr/zerank-1-small-ONNX")

+ def format_pair(query: str, doc: str) -> str:
+     messages = [
+         {"role": "system", "content": query},
+         {"role": "user", "content": doc},
+     ]
+     return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
  def rerank(query: str, documents: list[str]) -> list[float]:
      scores = []
      for doc in documents:
+         text = format_pair(query, doc)
+         enc = tok(text, return_tensors="np", truncation=True, max_length=MAX_LENGTH)
          logit = sess.run(["logits"], {
              "input_ids": enc["input_ids"].astype(np.int64),
              "attention_mask": enc["attention_mask"].astype(np.int64),
 
  scores = rerank(query, docs)
  for s, d in sorted(zip(scores, docs), reverse=True):
      print(f"[{s:.3f}] {d}")
+ # [+6.8] The giant panda is a bear species endemic to China.
+ # [+2.1] Pandas are mammals in the family Ursidae.
  # [-5.8] The sky is blue and the grass is green.
  ```

+ > **Batch inference:** The v2 export (`model.onnx`) supports `batch_size > 1` via a dynamic causal+padding mask. Pad a batch with the tokenizer and pass the full batch at once for higher throughput.
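Batched inference requires padding all pairs to one length. A minimal numpy-only sketch of right-padding (the token ids are made up for illustration; in practice `tok(texts, padding=True, return_tensors="np")` builds these arrays in one call):

```python
import numpy as np

# Hypothetical token id sequences of different lengths (illustration only).
sequences = [[151644, 872, 198], [151644, 872, 198, 9454, 151645]]
pad_id = 0
max_len = max(len(s) for s in sequences)

input_ids = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
attention_mask = np.zeros((len(sequences), max_len), dtype=np.int64)
for i, seq in enumerate(sequences):
    input_ids[i, : len(seq)] = seq      # right-pad: real tokens first
    attention_mask[i, : len(seq)] = 1   # 1 = real token, 0 = padding

# The padded arrays can then be fed to the session as one batch:
# sess.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})
print(input_ids.shape, attention_mask.sum(axis=1).tolist())
```

Right padding matters here: the export locates the scoring position as `attention_mask.sum - 1`, which assumes real tokens precede the padding.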

  ## Usage with fastembed-rs
 
 
      RerankInitOptions::new(RerankerModel::ZerankSmallInt8)
  ).unwrap();

+ // The chat template is applied automatically; batch_size > 1 is supported.
  let results = reranker.rerank(
      "What is a panda?",
      vec![
 
          "Pandas are mammals in the family Ursidae.",
      ],
      true,
+     Some(32),
  ).unwrap();

  for r in &results {
 
  ## Export details

+ `export_zerank_v2.py` wraps Qwen3ForCausalLM in a `ZeRankScorerV2` that:

+ 1. Builds a 4D causal+padding attention mask explicitly from `input_ids.shape[0]` — this makes the batch dimension dynamic in the ONNX graph (enabling `batch_size > 1`).
+ 2. Runs the transformer body → `hidden [batch, seq, hidden]`
+ 3. Gathers the hidden state at the last real-token position (`attention_mask.sum - 1`)
+ 4. Applies `lm_head`, slices the **"Yes" token** (id `9454`) → `[batch, 1]`

  Output: `logits [batch, 1]` — raw Yes-token logit (higher = more relevant). FP16 weights, opset 18.
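The gather-and-slice steps can be sketched with numpy. Random arrays stand in for the real hidden states and `lm_head`, and the sizes are toy values (the real "Yes" token id is 9454; a scaled-down stand-in is used here):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, hidden, vocab = 2, 6, 8, 16   # toy sizes; the real model is far larger
yes_id = 14                               # stand-in; the real "Yes" token id is 9454

hidden_states = rng.standard_normal((batch, seq, hidden))   # transformer output
lm_head = rng.standard_normal((hidden, vocab))              # vocab projection
attention_mask = np.array([[1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0]])

# Gather the hidden state at the last real-token position (attention_mask.sum - 1).
last_pos = attention_mask.sum(axis=1) - 1                   # [batch]
last_hidden = hidden_states[np.arange(batch), last_pos]     # [batch, hidden]

# Project with lm_head and slice out the single "Yes" logit.
logits = last_hidden @ lm_head                              # [batch, vocab]
yes_logit = logits[:, yes_id : yes_id + 1]                  # [batch, 1]
print(yes_logit.shape)
```

The keepdims-style slice (`yes_id : yes_id + 1` rather than plain indexing) is what yields the documented `[batch, 1]` output shape.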