#!/usr/bin/env python3
"""
Re-export zerank-1-small with dynamic batch support.

Key change from v1: ZeRankScorerV2 builds the 4D causal+padding attention mask
explicitly using input_ids.shape[0] (dynamic). This makes the batch dimension
symbolic in the ONNX graph — batch > 1 works correctly.

Also bakes the Qwen3 chat template into the expected input format:
  "<|im_start|>user\\nQuery: {q}\\nDocument: {d}\\nRelevant:<|im_end|>\\n<|im_start|>assistant\\n"
Tokenize the formatted string as a SINGLE sequence (not a pair) in fastembed.

Batches must be RIGHT-padded: the scorer reads the hidden state at index
sum(attention_mask) - 1 and passes no position_ids to the transformer.

Output: /private/tmp/zerank_export/zerank_onnx_v2/model.onnx + model.onnx_data (FP16)
(INT8/INT4 re-quantization: run stream_int8.py and export_int4.py after this)
"""

import gc
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

MODEL_ID = "zeroentropy/zerank-1-small"
YES_TOKEN_ID = 9454
OUT_DIR = Path("/private/tmp/zerank_export/zerank_onnx_v2")
OUT_MODEL = OUT_DIR / "model.onnx"

# Single source of truth for the prompt layout used by both the export dummy
# batch and the post-export verification.
PROMPT_TEMPLATE = (
    "<|im_start|>user\nQuery: {q}\nDocument: {d}\nRelevant:<|im_end|>\n"
    "<|im_start|>assistant\n"
)


class ZeRankScorerV2(nn.Module):
    """
    Wraps Qwen3ForCausalLM + last-token Yes-logit extraction.

    Difference from V1: builds 4D causal+padding mask explicitly so the batch
    dimension is dynamic in the ONNX graph (V1 had it hardcoded to 1).

    Input:
      input_ids      [batch, seq] — pre-formatted with chat template
      attention_mask [batch, seq] — 1 for real tokens, 0 for padding
                                    (right-padded; see forward())
    Output:
      logits [batch, 1] — raw Yes-token logit, higher = more relevant
    """

    def __init__(self, base_model, yes_token_id: int):
        """Store the wrapped causal LM, the target token id and its dtype.

        Args:
            base_model: module exposing ``.model(input_ids=, attention_mask=)``
                (returning a sequence whose first item is the hidden states)
                and ``.lm_head``.
            yes_token_id: vocabulary index whose logit is returned as the score.
        """
        super().__init__()
        self.base = base_model
        self.yes_token_id = yes_token_id
        # The additive mask must be built in the same dtype as the weights.
        self._dtype = next(base_model.parameters()).dtype

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the Yes-token logit at the last real token of each row.

        Args:
            input_ids: [batch, seq] token ids.
            attention_mask: [batch, seq], 1 for real tokens and 0 for padding.
                Rows must be right-padded, because the last real token is
                located at index ``sum(mask) - 1``.

        Returns:
            Tensor of shape [batch, 1].
        """
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]
        device = input_ids.device
        min_val = torch.finfo(self._dtype).min

        # Future positions (strict upper triangle) are blocked: [1, 1, seq, seq]
        future = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=device
        ).triu(diagonal=1).view(1, 1, seq_len, seq_len)

        # Padded key positions are blocked for every query: [batch, 1, 1, seq]
        padded = attention_mask.eq(0).unsqueeze(1).unsqueeze(2)

        # Combine as booleans and fill once. Summing two additive masks would
        # give min_val + min_val = -inf wherever a position is both a future
        # and a padding position; here every entry is exactly 0 or min_val.
        blocked = future | padded  # broadcasts to [batch, 1, seq, seq]
        full_mask = torch.zeros(
            batch_size, 1, seq_len, seq_len, dtype=self._dtype, device=device
        ).masked_fill(blocked, min_val)

        # Transformer body → [batch, seq, hidden]
        hidden = self.base.model(
            input_ids=input_ids,
            attention_mask=full_mask,
        )[0]

        # Gather at last real-token position: sum(mask) - 1
        last_pos = attention_mask.sum(dim=-1) - 1  # [batch]
        idx = last_pos.view(-1, 1, 1).expand(-1, 1, hidden.shape[-1])
        last_hidden = torch.gather(hidden, 1, idx).squeeze(1)  # [batch, hidden]

        yes_logit = self.base.lm_head(last_hidden)[:, self.yes_token_id]  # [batch]
        return yes_logit.unsqueeze(-1)  # [batch, 1]


def _load_tokenizer():
    """Load the model's tokenizer and force right padding.

    The scorer assumes right-padded batches, so the padding side is set
    explicitly instead of relying on the tokenizer's configured default.
    """
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    tok.padding_side = "right"
    return tok


def run_export() -> None:
    """Load the FP16 model, sanity-check batching, and export it to ONNX.

    Writes ``OUT_MODEL`` and then re-saves it with all large tensors stored in
    a single external file ``model.onnx_data`` next to it.

    Raises:
        AssertionError: if the PyTorch scorer gives different scores for the
            same row when run alone vs. inside a padded batch.
    """
    import onnx
    import torch.onnx as torch_onnx
    from onnx.external_data_helper import convert_model_to_external_data
    from transformers import Qwen3ForCausalLM

    # Created here (not at import time) so importing this module has no
    # filesystem side effects.
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    print(f"Loading {MODEL_ID}...")
    model = Qwen3ForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    ).eval()

    scorer = ZeRankScorerV2(model, YES_TOKEN_ID).eval()
    tok = _load_tokenizer()

    # Dummy batch=2 — forces dynamic batch to trace correctly
    pairs = [
        ("what is a panda?", "A panda is a large black-and-white bear."),
        ("what is a cat?", "A cat is a small domesticated carnivorous mammal."),
    ]
    formatted = [PROMPT_TEMPLATE.format(q=q, d=d) for q, d in pairs]
    enc = tok(formatted, padding=True, truncation=True, max_length=64, return_tensors="pt")
    dummy_ids = enc["input_ids"]
    dummy_mask = enc["attention_mask"]
    print(f" Dummy batch shape: {dummy_ids.shape}")

    # Verify correct batch behaviour before exporting
    with torch.no_grad():
        out_batch = scorer(dummy_ids, dummy_mask)
        out_single = scorer(dummy_ids[:1], dummy_mask[:1])
    batch_score = float(out_batch[0, 0])
    single_score = float(out_single[0, 0])
    # Explicit raise (not `assert`) so the check survives `python -O`;
    # written as `not (x < tol)` so a NaN score also fails.
    if not (abs(batch_score - single_score) < 0.01):
        raise AssertionError(
            f"Batch/single mismatch: {batch_score:.3f} vs {single_score:.3f}"
        )
    print(f" Batch consistency check PASS: {batch_score:.3f} vs {single_score:.3f}")

    print(f"Exporting to {OUT_MODEL} ...")
    with torch.no_grad():
        torch_onnx.export(
            scorer,
            (dummy_ids, dummy_mask),
            str(OUT_MODEL),
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "logits": {0: "batch_size"},
            },
            opset_version=18,
            do_constant_folding=False,
        )

    print(" Converting to external data format...")
    m = onnx.load(str(OUT_MODEL))
    convert_model_to_external_data(
        m,
        all_tensors_to_one_file=True,
        location="model.onnx_data",
        size_threshold=1024,
    )
    onnx.save(m, str(OUT_MODEL))

    print("Export complete:")
    for f in sorted(OUT_DIR.iterdir()):
        print(f" {f.name:40s} {f.stat().st_size / 1e6:.0f} MB")

    # Drop the large objects before the ONNX Runtime session is created.
    del m, scorer, model, tok, enc, dummy_ids, dummy_mask
    gc.collect()


def verify_batch() -> None:
    """Check that the exported ONNX model scores batches like single inputs.

    Runs each document alone and then all documents as one padded batch on
    the CPU execution provider, and compares the resulting logits.

    Raises:
        AssertionError: if any single/batch score pair differs by 0.1 or
            more, or if the on-topic document does not outrank the
            off-topic one.
    """
    import onnxruntime as ort

    print("\nVerifying batch > 1...")
    sess = ort.InferenceSession(str(OUT_MODEL), providers=["CPUExecutionProvider"])
    tok = _load_tokenizer()

    q = "what is a panda?"
    docs = [
        "The giant panda is a bear species endemic to China.",
        "The sky is blue.",
        "panda is an animal",
    ]
    formatted = [PROMPT_TEMPLATE.format(q=q, d=d) for d in docs]

    # Single inference
    single_scores = []
    for fmt in formatted:
        enc = tok(fmt, return_tensors="np", truncation=True, max_length=256)
        logit = sess.run(["logits"], {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
        })[0]
        single_scores.append(float(logit[0, 0]))

    # Batch inference
    enc = tok(formatted, return_tensors="np", truncation=True, max_length=256, padding=True)
    logits = sess.run(["logits"], {
        "input_ids": enc["input_ids"].astype(np.int64),
        "attention_mask": enc["attention_mask"].astype(np.int64),
    })[0]
    batch_scores = [float(v) for v in logits[:len(docs), 0]]

    print(" Single vs batch scores:")
    for d, s, b in zip(docs, single_scores, batch_scores):
        diff = abs(s - b)
        print(f" [{s:.3f} vs {b:.3f}] diff={diff:.4f} | {d[:50]}")
        # Explicit raise so the check is not stripped under `python -O`.
        if not (diff < 0.1):
            raise AssertionError(f"Mismatch too large: {diff}")

    if not (batch_scores[0] > batch_scores[1]):
        raise AssertionError("Panda should rank higher than sky")
    print(" OK — batch scores match single, correct ranking")


def main() -> None:
    """Export the model unless it already exists, then verify and print next steps."""
    if OUT_MODEL.exists():
        print(f"Model already exists at {OUT_MODEL}, skipping export.")
        print("Delete it to re-export.")
    else:
        run_export()
        gc.collect()

    verify_batch()

    print("\nNext steps:")
    # NOTE(review): the module docstring names stream_int8.py while this
    # message names stream_int8_v2.py — confirm which script is current.
    print(f" 1. Run stream_int8_v2.py to quantize INT8 from {OUT_MODEL}")
    print(f" 2. Upload to HF: huggingface-cli upload cstr/zerank-1-small-ONNX {OUT_DIR}/ . --repo-type model")


if __name__ == "__main__":
    main()