"""Accuracy tests for cross-encoder (reranker) models.

Scores from the SRT runtime are compared against HuggingFace reference
scores for the same query-document pairs.
"""

import multiprocessing as mp
import random
import unittest

import torch

from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, is_in_ci

register_cuda_ci(est_time=100, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=150, suite="stage-b-test-small-1-gpu-amd")
|
# Each entry is (model_path, tp_size, score_tolerance).
MODELS = [
    ("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
    ("BAAI/bge-reranker-v2-m3", 1, 1e-2),
]
ATTENTION_BACKEND = ["torch_native", "triton"]

TORCH_DTYPES = [torch.float32]
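# float32 keeps the HuggingFace reference and SRT scores directly comparable
# at the 1e-2 tolerances above; lower-precision dtypes would presumably need
# looser bounds.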
|
|
|
|
class TestCrossEncoderModels(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        # "spawn" avoids inheriting CUDA state into subprocesses.
        mp.set_start_method("spawn", force=True)
|
    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        score_tolerance,
        attention_backend,
    ) -> None:
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
        ) as hf_runner:
            hf_scores = hf_runner.forward(prompts).scores
|
        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,  # -1 disables chunked prefill
            disable_radix_cache=True,  # rule out prefix-cache reuse between requests
        ) as srt_runner:
            srt_scores = srt_runner.forward(prompts).scores
|
        for i in range(len(srt_scores)):
            score_difference = abs(hf_scores[i] - srt_scores[i])
            assert score_difference < score_tolerance, (
                f"cross-encoder scores are not close at index {i}: "
                f"hf={hf_scores[i]}, srt={srt_scores[i]}"
            )
|
    def preprocess_prompts(self, prompt):
        """Expand a {"query": ..., "documents": [...]} dict into [query, document] pairs."""
        query = prompt["query"]
        documents = prompt["documents"]
        return [[query, document] for document in documents]
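
    # For illustration, an assumed TEST_RERANK_QUERY_DOCS entry of the form
    #     {"query": "q", "documents": ["doc A", "doc B"]}
    # is expanded by preprocess_prompts into
    #     [["q", "doc A"], ["q", "doc B"]]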
|
|
    def test_prefill_logits(self):
        models_to_test = MODELS

        if is_in_ci():
            # In CI, test one randomly chosen model to bound runtime.
            models_to_test = [random.choice(MODELS)]

        for model, tp_size, score_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for query_docs in TEST_RERANK_QUERY_DOCS:
                    prompts = self.preprocess_prompts(query_docs)
                    for torch_dtype in TORCH_DTYPES:
                        self.assert_close_prefill_logits(
                            prompts,
                            model,
                            tp_size,
                            torch_dtype,
                            score_tolerance,
                            attention_backend,
                        )
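
# Example local invocation (module path assumed):
#   python3 -m unittest test_cross_encoder.TestCrossEncoderModels.test_prefill_logits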
|
|
|
|
if __name__ == "__main__":
    unittest.main()
|
|