cstr committed on
Commit
b19c63d
·
verified ·
1 Parent(s): b78d236

Upload export_zerank_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. export_zerank_v2.py +216 -0
export_zerank_v2.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Re-export zerank-1-small with dynamic batch support.

Key change from v1: ZeRankScorerV2 builds the 4D causal+padding attention mask
explicitly using input_ids.shape[0] (dynamic). This makes the batch dimension
symbolic in the ONNX graph — batch > 1 works correctly.

Also bakes the Qwen3 chat template into the expected input format:
"<|im_start|>user\\nQuery: {q}\\nDocument: {d}\\nRelevant:<|im_end|>\\n<|im_start|>assistant\\n"

Tokenize the formatted string as a SINGLE sequence (not a pair) in fastembed.

Output:
/private/tmp/zerank_export/zerank_onnx_v2/model.onnx + model.onnx_data (FP16)
(INT8/INT4 re-quantization: run stream_int8.py and export_int4.py after this)
"""

# stdlib
import gc
from pathlib import Path

# third-party
import numpy as np
import torch
import torch.nn as nn

# Checkpoint to re-export, and the vocab id whose logit is read as the
# relevance score.
MODEL_ID = "zeroentropy/zerank-1-small"
# NOTE(review): presumably the id of the "Yes" token in the Qwen3 tokenizer —
# confirm against tokenizer.convert_tokens_to_ids before changing models.
YES_TOKEN_ID = 9454

# Export destination; created eagerly at import time.
OUT_DIR = Path("/private/tmp/zerank_export/zerank_onnx_v2")
OUT_MODEL = OUT_DIR / "model.onnx"
OUT_DIR.mkdir(parents=True, exist_ok=True)
class ZeRankScorerV2(nn.Module):
    """
    Wraps Qwen3ForCausalLM + last-token Yes-logit extraction.

    Difference from V1: builds the 4D causal+padding mask explicitly so the
    batch dimension is dynamic in the ONNX graph (V1 had it hardcoded to 1).

    Input:
        input_ids      [batch, seq] — pre-formatted with chat template
        attention_mask [batch, seq] — 1 for real tokens, 0 for padding

    Output:
        logits [batch, 1] — raw Yes-token logit, higher = more relevant
    """

    def __init__(self, base_model, yes_token_id: int):
        super().__init__()
        self.base = base_model
        self.yes_token_id = yes_token_id
        # Mask fill value must match the model's compute dtype (fp16 here).
        self._dtype = next(base_model.parameters()).dtype

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]
        device = input_ids.device
        min_val = torch.finfo(self._dtype).min

        # A (query, key) slot is blocked when it is in the causal upper
        # triangle OR the key token is padding. Combine the two boolean masks
        # and apply a single masked_fill. This replaces the previous
        # `causal + pad` addition: wherever both masks contributed min_val,
        # the fp16 sum (-65504 + -65504) overflowed to -inf, and -inf mask
        # entries are fragile under downstream INT8/INT4 quantization.
        causal_blocked = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=device
        ).triu(diagonal=1)  # [seq, seq]
        key_is_pad = attention_mask.to(torch.bool).logical_not()  # [batch, seq]
        blocked = causal_blocked.view(1, 1, seq_len, seq_len) | key_is_pad.view(
            batch_size, 1, 1, seq_len
        )  # broadcast -> [batch, 1, seq, seq]

        full_mask = torch.zeros(
            batch_size, 1, seq_len, seq_len, dtype=self._dtype, device=device
        ).masked_fill(blocked, min_val)

        # Transformer body -> [batch, seq, hidden]
        hidden = self.base.model(
            input_ids=input_ids,
            attention_mask=full_mask,
        )[0]

        # Gather the hidden state at the last REAL token: index sum(mask) - 1.
        # NOTE(review): this indexing assumes right-padding; with a
        # left-padding tokenizer the index would point at padding — confirm
        # the tokenizer's padding_side at call sites.
        last_pos = attention_mask.sum(dim=-1) - 1  # [batch]
        idx = last_pos.view(-1, 1, 1).expand(-1, 1, hidden.shape[-1])
        last_hidden = torch.gather(hidden, 1, idx).squeeze(1)  # [batch, hidden]

        # Project only the Yes-token logit; higher = more relevant.
        yes_logit = self.base.lm_head(last_hidden)[:, self.yes_token_id]  # [batch]
        return yes_logit.unsqueeze(-1)  # [batch, 1]
88
def run_export():
    """Load the checkpoint, trace ZeRankScorerV2 on a batch-2 dummy input,
    and export FP16 ONNX with dynamic batch/sequence axes.

    Weights are then rewritten into a single external-data file
    (model.onnx_data) next to OUT_MODEL.
    """
    from transformers import Qwen3ForCausalLM, AutoTokenizer
    import torch.onnx as torch_onnx

    print(f"Loading {MODEL_ID}...")
    lm = Qwen3ForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="eager",  # NOTE(review): presumably fused attention kernels do not trace — confirm
    ).eval()

    wrapper = ZeRankScorerV2(lm, YES_TOKEN_ID).eval()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Dummy batch=2 — forces dynamic batch to trace correctly
    template = "<|im_start|>user\nQuery: {q}\nDocument: {d}\nRelevant:<|im_end|>\n<|im_start|>assistant\n"
    sample_pairs = [
        ("what is a panda?", "A panda is a large black-and-white bear."),
        ("what is a cat?", "A cat is a small domesticated carnivorous mammal."),
    ]
    rendered = [template.format(q=query, d=doc) for query, doc in sample_pairs]
    batch = tokenizer(rendered, padding=True, truncation=True, max_length=64, return_tensors="pt")
    sample_ids, sample_mask = batch["input_ids"], batch["attention_mask"]
    print(f" Dummy batch shape: {sample_ids.shape}")

    # Verify correct batch behaviour before exporting
    with torch.no_grad():
        scores_b2 = wrapper(sample_ids, sample_mask)
        scores_b1 = wrapper(sample_ids[:1], sample_mask[:1])
    assert abs(float(scores_b2[0, 0]) - float(scores_b1[0, 0])) < 0.01, \
        f"Batch/single mismatch: {float(scores_b2[0,0]):.3f} vs {float(scores_b1[0,0]):.3f}"
    print(f" Batch consistency check PASS: {float(scores_b2[0,0]):.3f} vs {float(scores_b1[0,0]):.3f}")

    print(f"Exporting to {OUT_MODEL} ...")
    with torch.no_grad():
        torch_onnx.export(
            wrapper,
            (sample_ids, sample_mask),
            str(OUT_MODEL),
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "logits": {0: "batch_size"},
            },
            opset_version=18,
            do_constant_folding=False,
        )

    import onnx
    from onnx.external_data_helper import convert_model_to_external_data
    print(" Converting to external data format...")
    exported = onnx.load(str(OUT_MODEL))
    convert_model_to_external_data(
        exported, all_tensors_to_one_file=True,
        location="model.onnx_data", size_threshold=1024,
    )
    onnx.save(exported, str(OUT_MODEL))
    print("Export complete:")
    for artifact in sorted(OUT_DIR.iterdir()):
        print(f" {artifact.name:40s} {artifact.stat().st_size / 1e6:.0f} MB")

    # Drop the large objects before the caller continues (export is memory-heavy).
    del exported, wrapper, lm, tokenizer, batch, sample_ids, sample_mask
    gc.collect()
158
def verify_batch():
    """Sanity-check the exported ONNX model.

    Batched scores must agree with one-at-a-time scores (the bug V2 fixes),
    and the relevant document must outrank the irrelevant one.
    """
    import onnxruntime as ort

    print(f"\nVerifying batch > 1...")
    sess = ort.InferenceSession(str(OUT_MODEL), providers=["CPUExecutionProvider"])

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    template = "<|im_start|>user\nQuery: {q}\nDocument: {d}\nRelevant:<|im_end|>\n<|im_start|>assistant\n"

    query = "what is a panda?"
    documents = [
        "The giant panda is a bear species endemic to China.",
        "The sky is blue.",
        "panda is an animal",
    ]

    def _run(ids, mask):
        # One ONNX forward pass; returns the [batch, 1] logits array.
        return sess.run(["logits"], {
            "input_ids": ids.astype(np.int64),
            "attention_mask": mask.astype(np.int64),
        })[0]

    # Single inference
    single_scores = []
    for doc in documents:
        enc = tokenizer(template.format(q=query, d=doc),
                        return_tensors="np", truncation=True, max_length=256)
        single_scores.append(float(_run(enc["input_ids"], enc["attention_mask"])[0, 0]))

    # Batch inference
    enc = tokenizer([template.format(q=query, d=doc) for doc in documents],
                    return_tensors="np", truncation=True, max_length=256, padding=True)
    batched = _run(enc["input_ids"], enc["attention_mask"])
    batch_scores = [float(batched[i, 0]) for i in range(len(documents))]

    print(" Single vs batch scores:")
    for doc, s_one, s_batch in zip(documents, single_scores, batch_scores):
        gap = abs(s_one - s_batch)
        print(f" [{s_one:.3f} vs {s_batch:.3f}] diff={gap:.4f} | {doc[:50]}")
        assert gap < 0.1, f"Mismatch too large: {gap}"
    assert batch_scores[0] > batch_scores[1], "Panda should rank higher than sky"
    print(" OK — batch scores match single, correct ranking")
204
if __name__ == "__main__":
    # Export only when no model is present; delete OUT_MODEL to force re-export.
    if not OUT_MODEL.exists():
        run_export()
        gc.collect()
    else:
        print(f"Model already exists at {OUT_MODEL}, skipping export.")
        print("Delete it to re-export.")

    verify_batch()

    print("\nNext steps:")
    print(f" 1. Run stream_int8_v2.py to quantize INT8 from {OUT_MODEL}")
    print(f" 2. Upload to HF: huggingface-cli upload cstr/zerank-1-small-ONNX {OUT_DIR}/ . --repo-type model")