Hanrui / sglang /test /registered /openai_server /basic /test_serving_rerank.py
Lekr0's picture
Add files using upload-large-folder tool
a402b9b verified
import asyncio
import unittest
from unittest.mock import Mock
from sglang.srt.entrypoints.openai.protocol import V1RerankReqInput
from sglang.srt.managers.tokenizer_manager_multiitem_mixin import ScoreResult
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
# CI registration for this test file: one entry for the CUDA suite and one for
# the AMD suite, each with est_time=10.
# Keep consistent with other openai_server/basic unit tests.
# NOTE(review): est_time is presumably a runtime estimate in seconds and the
# suite strings presumably must match names known to ci_register — confirm.
register_cuda_ci(est_time=10, suite="stage-b-test-large-1-gpu")
register_amd_ci(est_time=10, suite="stage-b-test-small-1-gpu-amd")
# Minimal unit-test environments may not have FastAPI or torch installed.
# When the import fails for exactly one of those two packages, leave
# OpenAIServingRerank as None so the test class below can be skipped; any
# other missing module is a real error and is re-raised unchanged.
try:
    from sglang.srt.entrypoints.openai.serving_rerank import (
        OpenAIServingRerank,
        _is_qwen3_reranker_template,
        _qwen3_rerank_score,
        _render_jinja_chat_template,
    )
except ModuleNotFoundError as import_error:
    if import_error.name not in ("fastapi", "torch"):
        raise
    OpenAIServingRerank = None  # type: ignore[assignment]
class _DummyModelConfig:
# Keep consistent with TokenizerManager.model_config usage
is_generation = False
class _DummyTokenizer:
chat_template = ""
class _DummyTokenizerManager:
# Minimal surface required by OpenAIServingBase/OpenAIServingRerank
server_args = object()
model_config = _DummyModelConfig()
tokenizer = _DummyTokenizer()
async def generate_request(self, *_args, **_kwargs):
raise AssertionError("generate_request should not be called in this unit test")
@unittest.skipIf(OpenAIServingRerank is None, "fastapi/torch is not installed")
class TestOpenAIServingRerankUnit(unittest.TestCase):
    """Unit tests for ``OpenAIServingRerank`` and its module-level helpers.

    All tests run against stub tokenizer managers defined in this file; the
    whole class is skipped when the serving module could not be imported.
    """

    def setUp(self):
        # Default handler backed by the stub manager (empty chat template).
        self.handler = OpenAIServingRerank(_DummyTokenizerManager())

    def test_convert_to_internal_request_cross_encoder_pairs(self):
        """An empty chat template yields a cross-encoder request of [query, doc] pairs."""
        req = V1RerankReqInput(
            query="q",
            documents=["doc-a", "doc-b"],
            instruct="Retrieve semantically similar text.",
        )

        adapted, processed = self.handler._convert_to_internal_request(req)

        # Avoid importing EmbeddingReqInput (requires torch). Use duck-typing checks instead.
        self.assertTrue(hasattr(adapted, "is_cross_encoder_request"))
        self.assertTrue(adapted.is_cross_encoder_request)
        self.assertEqual(adapted.text, [["q", "doc-a"], ["q", "doc-b"]])
        self.assertEqual(processed, req)

    def test_convert_to_internal_request_qwen3_template_returns_request(self):
        """With a Qwen3 reranker template the request object is passed through as-is."""
        tm = _DummyTokenizerManager()
        # Use a tokenizer owned by this manager alone, so editing its template
        # can never leak into handlers built by other tests.
        tm.tokenizer = _DummyTokenizer()
        tm.tokenizer.chat_template = (
            '... Note that the answer can only be "yes" or "no". ...'
        )
        handler = OpenAIServingRerank(tm)
        req = V1RerankReqInput(query="q", documents=["d1"])

        adapted, processed = handler._convert_to_internal_request(req)

        self.assertIs(adapted, req)
        self.assertIs(processed, req)

    def test_build_rerank_response_embedding_list_uses_first_scalar(self):
        """List-valued embeddings are scored by their first element."""
        req = V1RerankReqInput(
            query="q",
            documents=["doc-a", "doc-b"],
            return_documents=True,
        )
        # Two results with embedding as list, should coerce embedding[0] to float.
        # Also verifies sorting (doc-b > doc-a).
        ret = [
            {"embedding": [0.1, 0.2], "meta_info": {"id": "a"}},
            {"embedding": [0.9, -1.0], "meta_info": {"id": "b"}},
        ]

        res = self.handler._build_rerank_response(ret, req)

        self.assertEqual(len(res), 2)
        # Sorted descending by score, so doc-b first.
        self.assertEqual(res[0].document, "doc-b")
        self.assertEqual(res[0].index, 1)
        self.assertAlmostEqual(res[0].score, 0.9)
        self.assertEqual(res[0].meta_info, {"id": "b"})
        self.assertEqual(res[1].document, "doc-a")
        self.assertEqual(res[1].index, 0)
        self.assertAlmostEqual(res[1].score, 0.1)
        self.assertEqual(res[1].meta_info, {"id": "a"})

    def test_build_rerank_response_float_list(self):
        """A plain list of float scores is sorted descending with original indices kept."""
        req = V1RerankReqInput(
            query="q", documents=["a", "b", "c"], return_documents=True
        )
        scores = [0.2, 0.9, 0.1]

        res = self.handler._build_rerank_response(scores, req)

        self.assertEqual([r.document for r in res], ["b", "a", "c"])
        self.assertEqual([r.index for r in res], [1, 0, 2])
        self.assertAlmostEqual(res[0].score, 0.9)
        self.assertAlmostEqual(res[1].score, 0.2)
        self.assertAlmostEqual(res[2].score, 0.1)

    def test_helper_is_qwen3_reranker_template(self):
        """The yes/no instruction sentence marks a template as a Qwen3 reranker."""
        self.assertTrue(
            _is_qwen3_reranker_template(
                'Note that the answer can only be "yes" or "no".'
            )
        )
        self.assertFalse(_is_qwen3_reranker_template("plain template"))

    def test_helper_qwen3_rerank_score(self):
        """Scores follow p_yes / (p_yes + p_no), with 0.0 when both inputs are zero."""
        self.assertAlmostEqual(_qwen3_rerank_score(0.9, 0.1), 0.9)
        self.assertAlmostEqual(_qwen3_rerank_score(0.0, 0.0), 0.0)

    def test_helper_render_jinja_chat_template(self):
        """Rendering exposes ``instruct`` plus query/document as two messages."""
        # Skip if jinja2 isn't installed in this environment.
        try:
            import jinja2  # noqa: F401
        except ModuleNotFoundError:
            self.skipTest("jinja2 is not installed")

        tpl = (
            "{{ instruct | default('DEF') }}"
            "|{{ messages[0]['content'] }}"
            "|{{ messages[1]['content'] }}"
        )
        # instruct=None falls through to the template's own default.
        self.assertEqual(
            _render_jinja_chat_template(tpl, query="Q", document="D", instruct=None),
            "DEF|Q|D",
        )
        self.assertEqual(
            _render_jinja_chat_template(tpl, query="Q", document="D", instruct="I"),
            "I|Q|D",
        )

    def test_handle_non_streaming_request_qwen3_path_uses_score_prompts(self):
        """The Qwen3 text path scores documents through ``score_prompts``."""

        class _TM(_DummyTokenizerManager):
            def __init__(self):
                self.server_args = object()
                self.model_config = Mock()
                self.model_config.is_generation = True
                self.model_config.model_path = "qwen/qwen3"
                self.tokenizer = Mock()
                self.tokenizer.chat_template = (
                    'Note that the answer can only be "yes" or "no". '
                    "{{ messages[0]['content'] }} {{ messages[1]['content'] }}"
                )

            async def score_prompts(
                self, prompts, label_token_ids, apply_softmax, request
            ):
                # Return [p_yes, p_no] for each prompt
                assert len(prompts) == 2
                assert label_token_ids and len(label_token_ids) == 2
                return ScoreResult(scores=[[0.9, 0.1], [0.2, 0.8]], prompt_tokens=42)

        handler = OpenAIServingRerank(_TM())
        req = V1RerankReqInput(query="q", documents=["d1", "d2"], return_documents=True)
        adapted, _ = handler._convert_to_internal_request(req)
        raw_request = Mock()

        res = asyncio.run(
            handler._handle_non_streaming_request(adapted, req, raw_request)
        )

        self.assertEqual([r.document for r in res], ["d1", "d2"])
        self.assertAlmostEqual(res[0].score, 0.9 / (0.9 + 0.1))
        self.assertAlmostEqual(res[1].score, 0.2 / (0.2 + 0.8))

    def test_build_rerank_response_return_documents_false(self):
        """Test that document field is None when return_documents=False"""
        req = V1RerankReqInput(
            query="q", documents=["a", "b", "c"], return_documents=False
        )
        scores = [0.2, 0.9, 0.1]

        res = self.handler._build_rerank_response(scores, req)

        # All documents should be None
        self.assertEqual([r.document for r in res], [None, None, None])
        # But scores and indices should still be correct
        self.assertEqual([r.index for r in res], [1, 0, 2])
        self.assertAlmostEqual(res[0].score, 0.9)

    def test_build_rerank_response_top_n(self):
        """Test that top_n limits the number of returned results"""
        req = V1RerankReqInput(
            query="q", documents=["a", "b", "c"], return_documents=True, top_n=2
        )
        scores = [0.2, 0.9, 0.1]

        res = self.handler._build_rerank_response(scores, req)

        # Should only return top 2 results
        self.assertEqual(len(res), 2)
        self.assertEqual([r.document for r in res], ["b", "a"])
        self.assertEqual([r.index for r in res], [1, 0])
        self.assertAlmostEqual(res[0].score, 0.9)
        self.assertAlmostEqual(res[1].score, 0.2)

    def test_build_rerank_response_top_n_greater_than_total(self):
        """Test that top_n greater than total documents returns all documents"""
        req = V1RerankReqInput(
            query="q", documents=["a", "b"], return_documents=True, top_n=10
        )
        scores = [0.2, 0.9]

        res = self.handler._build_rerank_response(scores, req)

        # Should return all 2 documents even though top_n=10
        self.assertEqual(len(res), 2)
        self.assertEqual([r.document for r in res], ["b", "a"])

    def test_build_rerank_response_top_n_with_return_documents_false(self):
        """Test top_n works correctly with return_documents=False"""
        req = V1RerankReqInput(
            query="q", documents=["a", "b", "c"], return_documents=False, top_n=1
        )
        scores = [0.2, 0.9, 0.1]

        res = self.handler._build_rerank_response(scores, req)

        # Should only return top 1 result, and document should be None
        self.assertEqual(len(res), 1)
        self.assertIsNone(res[0].document)
        self.assertEqual(res[0].index, 1)
        self.assertAlmostEqual(res[0].score, 0.9)

    def test_handle_vl_reranker_request(self):
        """Test the Qwen3-VL reranker path with mocked logprobs."""
        import math

        # Mock tokenizer manager that supports generate_request
        class _TM(_DummyTokenizerManager):
            def __init__(self):
                self.server_args = object()
                self.model_config = Mock()
                self.model_config.is_generation = True
                self.model_config.model_path = "qwen/qwen3-vl"
                self.tokenizer = Mock()
                # Mock VL template detection
                self.tokenizer.chat_template = (
                    "{% for x in query %}{{ x.text }}{% endfor %}"
                    "{% for x in document %}{{ x.text }}{% endfor %}"
                    'answer can only be "yes" or "no" <|vision_start|>'
                )
                # Number of generate_request calls seen so far; selects which
                # canned logprobs the next call returns.
                self.call_count = 0

            async def generate_request(self, req, _raw):
                # Return logprobs for yes/no
                # Mock logprobs: P(yes) > P(no) for first doc, P(no) > P(yes) for second
                if self.call_count == 0:
                    # First doc: yes is likely
                    p_yes, p_no = 0.8, 0.2
                else:
                    # Second doc: no is likely
                    p_yes, p_no = 0.3, 0.7
                self.call_count += 1

                # Qwen3 token IDs: YES=9693, NO=2152
                top_logprobs = [
                    (math.log(p_yes), 9693, "yes"),
                    (math.log(p_no), 2152, "no"),
                ]
                # The rerank handler checks output_top_logprobs[0] for the first generated token
                meta_info = {"output_top_logprobs": [top_logprobs]}
                yield {"meta_info": meta_info, "embedding": None}

        handler = OpenAIServingRerank(_TM())
        req = V1RerankReqInput(
            query="query", documents=["doc1", "doc2"], return_documents=True
        )
        # Force VL path is handled by detection logic inside handler
        # We mocked chat_template to satisfy _is_qwen3_vl_reranker_template
        raw_request = Mock()

        res = asyncio.run(handler._handle_non_streaming_request(req, req, raw_request))

        self.assertEqual(len(res), 2)
        # First doc should have higher score
        self.assertEqual(res[0].document, "doc1")
        self.assertAlmostEqual(res[0].score, 0.8)  # 0.8 / (0.8+0.2) = 0.8
        self.assertEqual(res[1].document, "doc2")
        self.assertAlmostEqual(res[1].score, 0.3)  # 0.3 / (0.3+0.7) = 0.3
# Running this file as a script executes the tests above through unittest.
if __name__ == "__main__":
    unittest.main(verbosity=2)