| """Test original log probability alignment between SGLang and Hugging Face. |
| |
| This test suite verifies the correctness of the `origin_logprobs` output (temperature=1) |
| and the `logprobs` output (temperature=0.5) in SGLang by comparing it against |
| raw logit-based probabilities computed directly from a reference Hugging Face model. |
| |
| The test covers the following scenarios: |
| - Next-token prediction: Verifies that the log probability of the next token from |
| SGLang matches the Hugging Face model. |
| - Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are |
| consistent with Hugging Face outputs. |
| - Specified token IDs: Confirms that the original logprobs for specific token IDs |
| match the values computed from Hugging Face logits. |
| """ |
|
|
| import os |
| import random |
| import unittest |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| import sglang as sgl |
| from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci |
| from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST |
|
|
# CI registration: declares the estimated runtime (est_time) and the suite
# this test file is scheduled in, once for CUDA runners and once for AMD.
# NOTE(review): presumably CI tooling collects these calls, so they are kept
# as plain literal top-level calls -- confirm before restructuring them.
register_cuda_ci(est_time=41, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=60, suite="stage-b-test-small-1-gpu-amd")
|
|
| |
# Model loaded by both the Hugging Face reference and the SGLang engine.
MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

# Comparison tolerances passed to torch.allclose (purely relative: ATOL is 0).
RTOL = 0.20
ATOL = 0.00

# How many top logprobs to request per position, and how many randomly
# chosen token IDs to request explicit logprobs for.
TOP_LOGPROBS_NUM = 50
NUM_RANDOM_TOKEN_IDS = 10

# Prompts whose next-token distribution is compared between the two stacks.
PROMPTS = [
    "Hello, my name is",
    "The future of AI is",
    "The president of the United States is",
    "The capital of France is ",
]
| |
|
|
# Make the run reproducible. The test picks its "specified token IDs" with
# `random.sample`, so the stdlib generator must be seeded as well as torch;
# otherwise a failing set of token IDs cannot be reproduced on a re-run.
random.seed(1234)
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(1234)
    # Disable TF32 so the float32 reference computation is not silently
    # carried out at reduced precision on GPUs that support TF32.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
|
|
|
|
| class TestOriginalLogprob(unittest.TestCase): |
| def setUp(self): |
| |
| self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right") |
| self.hf_model = AutoModelForCausalLM.from_pretrained( |
| MODEL_ID, torch_dtype=torch.float32, device_map="auto" |
| ) |
|
|
| |
| self.sampling_params = { |
| "temperature": 0.5, |
| "top_p": 1.0, |
| "top_k": 10, |
| "max_new_tokens": 1, |
| } |
|
|
| |
| |
| |
| |
| def assert_logprobs_block_equal( |
| self, |
| hf_log_probs: torch.Tensor, |
| token_log_probs: list, |
| top_log_probs: list, |
| ids_log_probs: list, |
| random_token_ids: list, |
| tag: str = "", |
| ): |
| vals, idxs, _ = zip(*token_log_probs) |
| sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32) |
| sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long) |
| hf_vals = hf_log_probs[sgl_idxs] |
|
|
| self.assertTrue( |
| torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL), |
| msg=f"[{tag}] token‑level mismatch at indices {sgl_idxs.tolist()}", |
| ) |
|
|
| hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1) |
|
|
| sgl_topk = torch.tensor( |
| [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][ |
| :TOP_LOGPROBS_NUM |
| ], |
| dtype=torch.float32, |
| device=self.hf_model.device, |
| ) |
|
|
| k = min(hf_topk.numel(), sgl_topk.numel()) |
| self.assertTrue( |
| torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL), |
| msg=f"[{tag}] top‑k mismatch", |
| ) |
|
|
| indices = torch.tensor( |
| random_token_ids, dtype=torch.long, device=hf_log_probs.device |
| ) |
|
|
| hf_token_ids = hf_log_probs[indices] |
|
|
| sgl_token_ids = torch.tensor( |
| [v for v, _, _ in ids_log_probs[0]], |
| device=self.hf_model.device, |
| dtype=torch.float32, |
| ) |
| self.assertTrue( |
| torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL), |
| msg=f"[{tag}] token‑IDs mismatch", |
| ) |
|
|
| |
| max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item() |
| print(f"[{tag}] max|diff| token‑level = {max_diff:.4f}") |
|
|
| def test_logprob_match(self): |
| vocab_size = self.tokenizer.vocab_size |
|
|
| for env_val in ["True", "False"]: |
| with self.subTest(SGLANG_RETURN_ORIGINAL_LOGPROB=env_val): |
| os.environ["SGLANG_RETURN_ORIGINAL_LOGPROB"] = env_val |
|
|
| |
| sgl_engine = sgl.Engine( |
| model_path=MODEL_ID, |
| skip_tokenizer_init=True, |
| trust_remote_code=True, |
| mem_fraction_static=0.60, |
| ) |
|
|
| for prompt in PROMPTS: |
| random_token_ids = sorted( |
| random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS) |
| ) |
|
|
| enc = self.tokenizer(prompt, return_tensors="pt") |
| input_ids = enc["input_ids"].to(self.hf_model.device) |
| attn_mask = enc["attention_mask"].to(self.hf_model.device) |
|
|
| with torch.inference_mode(): |
| hf_out = self.hf_model( |
| input_ids=input_ids, |
| attention_mask=attn_mask, |
| return_dict=True, |
| ) |
| logits = hf_out.logits[:, -1, :] |
| hf_log_probs = F.log_softmax( |
| logits.float() / self.sampling_params["temperature"], dim=-1 |
| )[0] |
| hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0] |
|
|
| outputs = sgl_engine.generate( |
| input_ids=input_ids[0].tolist(), |
| sampling_params=self.sampling_params, |
| return_logprob=True, |
| top_logprobs_num=TOP_LOGPROBS_NUM, |
| token_ids_logprob=random_token_ids, |
| ) |
|
|
| if isinstance(outputs, list): |
| outputs = outputs[0] |
| meta = outputs["meta_info"] |
|
|
| |
| if env_val.lower() == "true": |
| self.assert_logprobs_block_equal( |
| hf_log_probs=hf_original_log_probs, |
| token_log_probs=meta["output_token_logprobs"], |
| top_log_probs=meta["output_top_logprobs"], |
| ids_log_probs=meta["output_token_ids_logprobs"], |
| random_token_ids=random_token_ids, |
| tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})", |
| ) |
| else: |
| |
| self.assert_logprobs_block_equal( |
| hf_log_probs=hf_log_probs, |
| token_log_probs=meta["output_token_logprobs"], |
| top_log_probs=meta["output_top_logprobs"], |
| ids_log_probs=meta["output_token_ids_logprobs"], |
| random_token_ids=random_token_ids, |
| tag=f"logprobs SGLang vs HF: {prompt} ({env_val})", |
| ) |
| sgl_engine.shutdown() |
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|