| """ |
| Usage: |
| # replay from a folder |
| python3 replay_request_dump.py --file-number 100 --parallel 512 --input-folder /data/lianmin/sglang_request_dump/grok-mini-0220-engine-5756f8f94-28bm6/ |
| |
| # replay from a single file |
| python3 replay_request_dump.py --parallel 512 --input-file /data/sglang_crash_dump/memx-cti-34-sr1.xpop.twttr.net/crash_dump_2025-06-04_20-13-18.pkl |
| """ |
|
|
| import argparse |
| import glob |
| import json |
| import pickle |
| import time |
| from concurrent.futures import ThreadPoolExecutor |
| from dataclasses import asdict |
| from datetime import datetime |
|
|
| import requests |
|
|
| from sglang.benchmark.utils import set_ulimit |
| from sglang.utils import get_exception_traceback |
|
|
|
|
| def normalize_mm_data_item(item): |
| if isinstance(item, dict) and "url" in item: |
| return item["url"] |
| return item |
|
|
|
|
| def normalize_mm_data(data): |
| if data is None: |
| return None |
| if isinstance(data, list): |
| return [ |
| ( |
| [normalize_mm_data_item(item) for item in sublist] |
| if isinstance(sublist, list) |
| else normalize_mm_data_item(sublist) |
| ) |
| for sublist in data |
| ] |
| return normalize_mm_data_item(data) |
|
|
|
|
| def normalize_request_data(json_data): |
| """Normalize multimodal fields in request data for replay compatibility.""" |
| for field in ["image_data", "video_data", "audio_data"]: |
| if field in json_data and json_data[field] is not None: |
| json_data[field] = normalize_mm_data(json_data[field]) |
| return json_data |
|
|
|
|
| def read_records(files): |
| records = [] |
| for f in files: |
| tmp = pickle.load(open(f, "rb")) |
| if isinstance(tmp, dict) and "requests" in tmp: |
| records.extend(tmp["requests"]) |
| else: |
| records.extend(tmp) |
|
|
| return records |
|
|
|
|
| def run_one_request_internal(record): |
| req, output, replay_init_time, start_time, end_time, idx = record |
| time.sleep(max(0, (start_time - (time.time() - replay_init_time)) / args.speed)) |
|
|
| if "completion_tokens" in output.get("meta_info", {}): |
| recorded_completion_tokens = output["meta_info"]["completion_tokens"] |
| else: |
| recorded_completion_tokens = "" |
|
|
| json_data = normalize_request_data(asdict(req)) |
| stream = json_data["stream"] |
|
|
| if args.ignore_eos: |
| json_data["sampling_params"]["ignore_eos"] = True |
| if recorded_completion_tokens: |
| json_data["sampling_params"]["max_new_tokens"] = recorded_completion_tokens |
|
|
| response = requests.post( |
| f"http://{args.host}:{args.port}/generate", |
| json=json_data, |
| stream=stream, |
| ) |
|
|
| if stream: |
| for chunk in response.iter_lines(decode_unicode=False): |
| chunk = chunk.decode("utf-8") |
| if chunk and chunk.startswith("data:"): |
| if chunk == "data: [DONE]": |
| break |
| ret = json.loads(chunk[5:].strip("\n")) |
| else: |
| ret = response.json() |
|
|
| prompt_tokens = ret["meta_info"]["prompt_tokens"] |
| completion_tokens = ret["meta_info"]["completion_tokens"] |
| print( |
| f"{idx=}, {start_time=:.2f}, {prompt_tokens=}, " |
| f"{completion_tokens=}, {recorded_completion_tokens=}" |
| ) |
|
|
|
|
| def run_one_request(record): |
| |
|
|
| try: |
| run_one_request_internal(record) |
| |
| except Exception: |
| |
| traceback = get_exception_traceback() |
| print(f"Hit an exception: {traceback}") |
|
|
|
|
| def main(records): |
| if len(records) == 0: |
| return |
|
|
| base_time = records[0][-2] |
| base_time_str = datetime.fromtimestamp(base_time).strftime("%y-%m-%d %H:%M:%S") |
| print(f"{base_time_str=}") |
| replay_init_time = time.time() |
|
|
| for i in range(len(records)): |
| req, output, start_time, end_time = records[i] |
| start_time -= base_time |
| records[i] = (req, output, replay_init_time, start_time, end_time, i) |
|
|
| with ThreadPoolExecutor(args.parallel) as executor: |
| executor.map(run_one_request, records) |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--host", type=str, default="localhost") |
| parser.add_argument("--port", type=int, default=30000) |
| parser.add_argument( |
| "--input-folder", type=str, default=None, help="Folder containing pickle files" |
| ) |
| parser.add_argument( |
| "--input-file", type=str, default=None, help="Single pickle file to process" |
| ) |
| parser.add_argument("--file-number", type=int, default=1) |
| parser.add_argument("--req-number", type=int, default=1000000) |
| parser.add_argument("--req-start", type=int, default=0) |
| parser.add_argument("--parallel", type=int, default=512) |
| parser.add_argument("--idx", type=int, default=None) |
| parser.add_argument("--ignore-eos", action="store_true") |
| parser.add_argument("--speed", type=float, default=1) |
| args = parser.parse_args() |
|
|
| set_ulimit() |
|
|
| files = [] |
| if args.input_file: |
| files = [args.input_file] |
| if args.file_number > 1: |
| print("Warning: --file-number is ignored when --input-file is provided.") |
| elif args.input_folder: |
| files = glob.glob(f"{args.input_folder}/*.pkl") |
| files = files[: args.file_number] |
| else: |
| print("Error: Either --input-folder or --input-file must be provided.") |
| exit(1) |
| print(f"{files=}") |
|
|
| records = read_records(files) |
| |
| records.sort(key=lambda x: x[-2]) |
| records = records[args.req_start :] |
| if args.idx: |
| records = [records[args.idx]] |
| print(f"testing {args.idx=}") |
| print(f"{records[0]}") |
| print(f"{len(records)=}") |
| main(records) |
|
|