| import hashlib |
| import json |
| import os |
| import shutil |
| import subprocess |
| import sys |
| import tempfile |
| import unittest |
| import warnings |
| from contextlib import nullcontext |
| from io import StringIO |
|
|
| import requests |
| from huggingface_hub import snapshot_download |
|
|
| from sglang.srt.utils import kill_process_tree |
| from sglang.srt.utils.model_file_verifier import ( |
| IntegrityError, |
| compute_sha256, |
| generate_checksums, |
| verify, |
| ) |
| from sglang.test.ci.ci_register import register_cuda_ci |
| from sglang.test.test_utils import ( |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| DEFAULT_URL_FOR_TEST, |
| popen_launch_server, |
| ) |
|
|
| |
| register_cuda_ci(est_time=120, suite="nightly-1-gpu", nightly=True) |
|
|
| MODEL_NAME = "Qwen/Qwen3-0.6B" |
|
|
|
|
| |
|
|
|
|
| class _FakeModelTestCase(unittest.TestCase): |
|
|
| FAKE_FILES = { |
| "model.safetensors": b"fake safetensors content " * 100, |
| "config.json": b'{"model_type": "llama"}', |
| "tokenizer.json": b'{"version": "1.0"}', |
| } |
|
|
| def setUp(self): |
| self.test_dir = tempfile.mkdtemp() |
| for filename, content in self.FAKE_FILES.items(): |
| _create_test_file(self.test_dir, filename, content) |
|
|
| def tearDown(self): |
| shutil.rmtree(self.test_dir, ignore_errors=True) |
|
|
|
|
| class _RealModelTestCase(unittest.TestCase): |
|
|
| @classmethod |
| def setUpClass(cls): |
| cls.original_model_path = snapshot_download(MODEL_NAME) |
|
|
| def setUp(self): |
| self.test_dir = tempfile.mkdtemp() |
| shutil.copytree(self.original_model_path, self.test_dir, dirs_exist_ok=True) |
|
|
| def tearDown(self): |
| shutil.rmtree(self.test_dir, ignore_errors=True) |
|
|
|
|
| |
|
|
|
|
| class TestModelFileVerifier(_FakeModelTestCase): |
|
|
| def test_detect_bit_rot(self): |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| generate_checksums(source=self.test_dir, output_path=checksums_file) |
|
|
| target_file = os.path.join(self.test_dir, "model.safetensors") |
| _flip_bit_in_file(target_file, byte_offset=50, bit_position=3) |
|
|
| with self.assertRaises(IntegrityError) as ctx: |
| verify(model_path=self.test_dir, checksums_source=checksums_file) |
|
|
| self.assertIn("model.safetensors", str(ctx.exception)) |
| self.assertIn("mismatch", str(ctx.exception).lower()) |
|
|
| def test_detect_missing_file(self): |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| generate_checksums(source=self.test_dir, output_path=checksums_file) |
|
|
| os.remove(os.path.join(self.test_dir, "config.json")) |
|
|
| with self.assertRaises(IntegrityError) as ctx: |
| verify(model_path=self.test_dir, checksums_source=checksums_file) |
|
|
| self.assertIn("config.json", str(ctx.exception)) |
|
|
| def test_compute_sha256(self): |
| test_file = os.path.join(self.test_dir, "test.bin") |
| content = b"hello world" |
| with open(test_file, "wb") as f: |
| f.write(content) |
|
|
| result = compute_sha256(file_path=test_file) |
| expected = hashlib.sha256(content).hexdigest() |
| self.assertEqual(result, expected) |
|
|
| def test_parallel_checksum_computation(self): |
| for i in range(10): |
| _create_test_file( |
| self.test_dir, f"shard_{i}.safetensors", f"content_{i}".encode() * 1000 |
| ) |
|
|
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| result = generate_checksums( |
| source=self.test_dir, output_path=checksums_file, max_workers=4 |
| ) |
|
|
| self.assertGreaterEqual(len(result.files), 10) |
|
|
| def test_generated_json_snapshot(self): |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| generate_checksums(source=self.test_dir, output_path=checksums_file) |
|
|
| with open(checksums_file) as f: |
| data = json.load(f) |
|
|
| expected = { |
| "files": { |
| "config.json": { |
| "sha256": "81dddc8c379baae137d99d24c5fa081d3a5ce52b6a221ddc22fe364711f8beaf", |
| "size": 23, |
| }, |
| "model.safetensors": { |
| "sha256": "eb0c73a48a89fefb6b68dd41af830d75610c885135eac99139373b04705d05f3", |
| "size": 2500, |
| }, |
| "tokenizer.json": { |
| "sha256": "4e3043229142b64d998563bc543ce034e0a2251af5d404995e3afcb8ce8850df", |
| "size": 18, |
| }, |
| } |
| } |
| self.assertEqual(data, expected) |
|
|
| def test_legacy_checksums_format_deprecated(self): |
| legacy_data = { |
| "checksums": { |
| "model.safetensors": "eb0c73a48a89fefb6b68dd41af830d75610c885135eac99139373b04705d05f3", |
| "config.json": "81dddc8c379baae137d99d24c5fa081d3a5ce52b6a221ddc22fe364711f8beaf", |
| "tokenizer.json": "4e3043229142b64d998563bc543ce034e0a2251af5d404995e3afcb8ce8850df", |
| } |
| } |
| legacy_file = os.path.join(self.test_dir, "legacy_checksums.json") |
| with open(legacy_file, "w") as f: |
| json.dump(legacy_data, f) |
|
|
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter("always") |
| verify(model_path=self.test_dir, checksums_source=legacy_file) |
| self.assertEqual(len(w), 1) |
| self.assertTrue(issubclass(w[0].category, DeprecationWarning)) |
| self.assertIn("deprecated", str(w[0].message).lower()) |
|
|
|
|
| |
|
|
|
|
| class TestModelFileVerifierCLI(_FakeModelTestCase): |
|
|
| def test_cli_generate(self): |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| result = subprocess.run( |
| [ |
| sys.executable, |
| "-m", |
| "sglang.srt.utils.model_file_verifier", |
| "generate", |
| "--model-path", |
| self.test_dir, |
| "--model-checksum", |
| checksums_file, |
| ], |
| capture_output=True, |
| text=True, |
| ) |
| self.assertEqual(result.returncode, 0, f"stderr: {result.stderr}") |
| self.assertTrue(os.path.exists(checksums_file)) |
|
|
| with open(checksums_file) as f: |
| data = json.load(f) |
| self.assertIn("files", data) |
| self.assertEqual(len(data["files"]), 3) |
|
|
| def test_cli_verify_success(self): |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| generate_checksums(source=self.test_dir, output_path=checksums_file) |
|
|
| result = subprocess.run( |
| [ |
| sys.executable, |
| "-m", |
| "sglang.srt.utils.model_file_verifier", |
| "verify", |
| "--model-path", |
| self.test_dir, |
| "--model-checksum", |
| checksums_file, |
| ], |
| capture_output=True, |
| text=True, |
| ) |
| self.assertEqual(result.returncode, 0, f"stderr: {result.stderr}") |
| self.assertIn("verified successfully", result.stdout) |
|
|
| def test_cli_verify_fails_on_corruption(self): |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| generate_checksums(source=self.test_dir, output_path=checksums_file) |
|
|
| target_file = os.path.join(self.test_dir, "model.safetensors") |
| _flip_bit_in_file(target_file, byte_offset=50, bit_position=3) |
|
|
| result = subprocess.run( |
| [ |
| sys.executable, |
| "-m", |
| "sglang.srt.utils.model_file_verifier", |
| "verify", |
| "--model-path", |
| self.test_dir, |
| "--model-checksum", |
| checksums_file, |
| ], |
| capture_output=True, |
| text=True, |
| ) |
| self.assertNotEqual(result.returncode, 0) |
| combined = result.stdout + result.stderr |
| self.assertTrue( |
| "IntegrityError" in combined or "mismatch" in combined.lower(), |
| f"Expected integrity error, got: {combined}", |
| ) |
|
|
|
|
| |
|
|
|
|
| class TestModelFileVerifierHF(_RealModelTestCase): |
|
|
| def test_generate_checksums_from_hf(self): |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| result = generate_checksums(source=MODEL_NAME, output_path=checksums_file) |
|
|
| self.assertTrue(os.path.exists(checksums_file)) |
| self.assertGreater(len(result.files), 0) |
| for filename, file_info in result.files.items(): |
| self.assertEqual(len(file_info.sha256), 64) |
|
|
| def test_verify_with_hf_checksums_source(self): |
| verify(model_path=self.test_dir, checksums_source=MODEL_NAME) |
|
|
|
|
| |
|
|
|
|
| class TestModelFileVerifierWithRealModel(_RealModelTestCase): |
|
|
| def _run_server_test(self, *, corrupt_weights: bool, use_hf_checksum: bool): |
| if use_hf_checksum: |
| checksum_arg = MODEL_NAME |
| else: |
| checksums_file = os.path.join(self.test_dir, "checksums.json") |
| generate_checksums(source=self.test_dir, output_path=checksums_file) |
| checksum_arg = checksums_file |
|
|
| corrupted_file = None |
| if corrupt_weights: |
| safetensors_files = [ |
| f for f in os.listdir(self.test_dir) if f.endswith(".safetensors") |
| ] |
| self.assertTrue(len(safetensors_files) > 0, "No safetensors files found") |
| corrupted_file = safetensors_files[0] |
| _flip_bit_in_file(os.path.join(self.test_dir, corrupted_file)) |
|
|
| stdout_io, stderr_io = StringIO(), StringIO() |
| ctx = self.assertRaises(Exception) if corrupt_weights else nullcontext() |
| with ctx: |
| process = popen_launch_server( |
| model=self.test_dir, |
| base_url=DEFAULT_URL_FOR_TEST, |
| timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| other_args=["--model-checksum", checksum_arg], |
| return_stdout_stderr=(stdout_io, stderr_io), |
| ) |
|
|
| if corrupt_weights: |
| output = stdout_io.getvalue() + stderr_io.getvalue() |
| self.assertIn(corrupted_file, output) |
| self.assertIn("mismatch", output.lower()) |
| else: |
| try: |
| response = requests.post( |
| f"{DEFAULT_URL_FOR_TEST}/generate", |
| json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}}, |
| ) |
| self.assertEqual(response.status_code, 200) |
| self.assertIn("text", response.json()) |
| finally: |
| kill_process_tree(process.pid) |
|
|
| def test_server_launch_with_checksum_intact(self): |
| self._run_server_test(corrupt_weights=False, use_hf_checksum=False) |
|
|
| def test_server_launch_fails_with_corrupted_weights(self): |
| self._run_server_test(corrupt_weights=True, use_hf_checksum=False) |
|
|
| def test_server_launch_with_hf_checksum_intact(self): |
| self._run_server_test(corrupt_weights=False, use_hf_checksum=True) |
|
|
| def test_server_launch_with_hf_checksum_corrupted(self): |
| self._run_server_test(corrupt_weights=True, use_hf_checksum=True) |
|
|
|
|
| |
|
|
|
|
| def _create_test_file(directory: str, filename: str, content: bytes) -> str: |
| path = os.path.join(directory, filename) |
| with open(path, "wb") as f: |
| f.write(content) |
| return path |
|
|
|
|
| def _flip_bit_in_file(file_path: str, byte_offset: int = 100, bit_position: int = 0): |
| file_size = os.path.getsize(file_path) |
| assert ( |
| byte_offset < file_size |
| ), f"byte_offset {byte_offset} >= file_size {file_size}" |
|
|
| with open(file_path, "r+b") as f: |
| f.seek(byte_offset) |
| original_byte = f.read(1)[0] |
| f.seek(byte_offset) |
| f.write(bytes([original_byte ^ (1 << bit_position)])) |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main() |
|
|