| |
| """ |
| Re-export zerank-1-small with dynamic batch support. |
| |
| Key change from v1: ZeRankScorerV2 builds the 4D causal+padding attention mask |
| explicitly using input_ids.shape[0] (dynamic). This makes the batch dimension |
| symbolic in the ONNX graph — batch > 1 works correctly. |
| |
| Also bakes the Qwen3 chat template into the expected input format: |
| "<|im_start|>user\\nQuery: {q}\\nDocument: {d}\\nRelevant:<|im_end|>\\n<|im_start|>assistant\\n" |
| |
| Tokenize the formatted string as a SINGLE sequence (not a pair) in fastembed. |
| |
| Output: |
| /private/tmp/zerank_export/zerank_onnx_v2/model.onnx + model.onnx_data (FP16) |
| (INT8/INT4 re-quantization: run stream_int8.py and export_int4.py after this) |
| """ |
|
|
| import gc |
| from pathlib import Path |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
# Hugging Face model id of the reranker being exported.
MODEL_ID = "zeroentropy/zerank-1-small"
# Vocab index whose logit is used as the relevance score; presumably the
# "Yes" token id in the Qwen3 tokenizer — TODO confirm against the vocab.
YES_TOKEN_ID = 9454


# Export destination. NOTE: the directory is created eagerly at import time
# (module-level side effect), so merely importing this file touches disk.
OUT_DIR = Path("/private/tmp/zerank_export/zerank_onnx_v2")
OUT_MODEL = OUT_DIR / "model.onnx"
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
class ZeRankScorerV2(nn.Module):
    """
    Wraps a Qwen3-style causal LM and scores relevance as the raw logit of
    the Yes token at the last real (non-padding) position.

    Difference from V1: builds the 4D causal+padding mask explicitly from
    input_ids.shape[0], so the batch dimension stays dynamic in the traced
    ONNX graph (V1 had it hardcoded to 1).

    Input:
        input_ids      [batch, seq] — pre-formatted with chat template
        attention_mask [batch, seq] — 1 for real tokens, 0 for padding

    Output:
        logits [batch, 1] — raw Yes-token logit, higher = more relevant
    """
    def __init__(self, base_model, yes_token_id: int):
        """
        Args:
            base_model: causal LM exposing `.model` (backbone whose output
                tuple has hidden states at index 0) and `.lm_head`
                (hidden -> vocab logits projection).
            yes_token_id: vocab index whose logit becomes the score.
        """
        super().__init__()
        self.base = base_model
        self.yes_token_id = yes_token_id
        # Cache the parameter dtype so the mask's fill value matches the
        # model's compute dtype (fp16 at export time).
        self._dtype = next(base_model.parameters()).dtype

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]
        device = input_ids.device
        min_val = torch.finfo(self._dtype).min

        # Causal component: min_val strictly above the diagonal, 0 elsewhere,
        # broadcast over the (dynamic) batch dimension.
        upper = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(diagonal=1)
        causal = torch.zeros(1, 1, seq_len, seq_len, dtype=self._dtype, device=device)
        causal = causal.masked_fill(upper.view(1, 1, seq_len, seq_len), min_val)
        causal = causal.expand(batch_size, 1, seq_len, seq_len)

        # Padding component: min_val in every column that is a pad token.
        pad = (1.0 - attention_mask.to(self._dtype)) * min_val
        pad = pad.unsqueeze(1).unsqueeze(2)
        pad = pad.expand(batch_size, 1, seq_len, seq_len)

        # FIX: where the causal and padding components overlap, the sum is
        # 2 * min_val, which overflows to -inf in fp16 and can produce NaN in
        # some attention/softmax backends. Clamp back to min_val (same trick
        # as HF transformers' AttentionMaskConverter): masked positions stay
        # fully masked, unmasked positions are unchanged.
        full_mask = (causal + pad).clamp(min=min_val)

        # Backbone returns hidden states as element 0 of its output tuple.
        hidden = self.base.model(
            input_ids=input_ids,
            attention_mask=full_mask,
        )[0]

        # Gather the hidden state of the last real token in each row
        # (assumes right padding, so last real index == mask.sum() - 1
        # — TODO confirm the tokenizer pads on the right).
        last_pos = attention_mask.sum(dim=-1) - 1
        idx = last_pos.view(-1, 1, 1).expand(-1, 1, hidden.shape[-1])
        last_hidden = torch.gather(hidden, 1, idx).squeeze(1)

        # Project only the final hidden state and keep the Yes-token column.
        yes_logit = self.base.lm_head(last_hidden)[:, self.yes_token_id]
        return yes_logit.unsqueeze(-1)
|
|
|
|
def run_export():
    """Load the HF model, sanity-check batched scoring, and export FP16 ONNX."""
    import torch.onnx as torch_onnx
    from transformers import AutoTokenizer, Qwen3ForCausalLM

    print(f"Loading {MODEL_ID}...")
    base = Qwen3ForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    wrapper = ZeRankScorerV2(base, YES_TOKEN_ID).eval()

    # Build a small dummy batch formatted with the baked-in chat template.
    template = "<|im_start|>user\nQuery: {q}\nDocument: {d}\nRelevant:<|im_end|>\n<|im_start|>assistant\n"
    sample_pairs = [
        ("what is a panda?", "A panda is a large black-and-white bear."),
        ("what is a cat?", "A cat is a small domesticated carnivorous mammal."),
    ]
    texts = [template.format(q=query, d=doc) for query, doc in sample_pairs]
    batch = tokenizer(texts, padding=True, truncation=True, max_length=64, return_tensors="pt")
    sample_ids = batch["input_ids"]
    sample_mask = batch["attention_mask"]
    print(f" Dummy batch shape: {sample_ids.shape}")

    # Scores must agree between batch > 1 and batch == 1 before exporting.
    with torch.no_grad():
        scores_batched = wrapper(sample_ids, sample_mask)
        scores_single = wrapper(sample_ids[:1], sample_mask[:1])
    first_b = float(scores_batched[0, 0])
    first_s = float(scores_single[0, 0])
    assert abs(first_b - first_s) < 0.01, \
        f"Batch/single mismatch: {first_b:.3f} vs {first_s:.3f}"
    print(f" Batch consistency check PASS: {first_b:.3f} vs {first_s:.3f}")

    print(f"Exporting to {OUT_MODEL} ...")
    dynamic = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size"},
    }
    with torch.no_grad():
        torch_onnx.export(
            wrapper,
            (sample_ids, sample_mask),
            str(OUT_MODEL),
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes=dynamic,
            opset_version=18,
            do_constant_folding=False,
        )

    # Move the weights into a sidecar file (keeps the graph protobuf small).
    import onnx
    from onnx.external_data_helper import convert_model_to_external_data

    print(" Converting to external data format...")
    exported = onnx.load(str(OUT_MODEL))
    convert_model_to_external_data(
        exported,
        all_tensors_to_one_file=True,
        location="model.onnx_data",
        size_threshold=1024,
    )
    onnx.save(exported, str(OUT_MODEL))
    print("Export complete:")
    for path in sorted(OUT_DIR.iterdir()):
        print(f" {path.name:40s} {path.stat().st_size / 1e6:.0f} MB")

    # Release the fp16 weights before any downstream quantization step.
    del exported, wrapper, base, tokenizer, batch, sample_ids, sample_mask
    gc.collect()
|
|
|
|
def verify_batch():
    """Check the exported graph: batched scores must match single-item scores."""
    import onnxruntime as ort
    from transformers import AutoTokenizer

    print("\nVerifying batch > 1...")
    session = ort.InferenceSession(str(OUT_MODEL), providers=["CPUExecutionProvider"])
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    template = "<|im_start|>user\nQuery: {q}\nDocument: {d}\nRelevant:<|im_end|>\n<|im_start|>assistant\n"

    query = "what is a panda?"
    docs = [
        "The giant panda is a bear species endemic to China.",
        "The sky is blue.",
        "panda is an animal",
    ]

    def _run(feed):
        # ORT expects int64 inputs; returns the "logits" output array.
        return session.run(["logits"], {
            "input_ids": feed["input_ids"].astype(np.int64),
            "attention_mask": feed["attention_mask"].astype(np.int64),
        })[0]

    # Score each document on its own (no padding involved).
    single_scores = []
    for doc in docs:
        encoded = tokenizer(template.format(q=query, d=doc),
                            return_tensors="np", truncation=True, max_length=256)
        single_scores.append(float(_run(encoded)[0, 0]))

    # Score all documents in one padded batch.
    encoded = tokenizer([template.format(q=query, d=doc) for doc in docs],
                        return_tensors="np", truncation=True, max_length=256, padding=True)
    logits = _run(encoded)
    batch_scores = [float(v) for v in logits[:, 0]]

    print(" Single vs batch scores:")
    for doc, s, b in zip(docs, single_scores, batch_scores):
        diff = abs(s - b)
        print(f" [{s:.3f} vs {b:.3f}] diff={diff:.4f} | {doc[:50]}")
        assert diff < 0.1, f"Mismatch too large: {diff}"
    assert batch_scores[0] > batch_scores[1], "Panda should rank higher than sky"
    print(" OK — batch scores match single, correct ranking")
|
|
|
|
if __name__ == "__main__":
    # Skip the (slow) export when a previous run already produced the model.
    if not OUT_MODEL.exists():
        run_export()
        gc.collect()
    else:
        print(f"Model already exists at {OUT_MODEL}, skipping export.")
        print("Delete it to re-export.")

    verify_batch()

    print("\nNext steps:")
    print(f" 1. Run stream_int8_v2.py to quantize INT8 from {OUT_MODEL}")
    print(f" 2. Upload to HF: huggingface-cli upload cstr/zerank-1-small-ONNX {OUT_DIR}/ . --repo-type model")
|
|