| """ | |
| Benchmark the latency of running a single static batch without a server. | |
| This script does not launch a server and uses the low-level APIs. | |
| It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths). | |
| # Usage (latency test): | |
| ## with dummy weights: | |
| python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy | |
| ## sweep through multiple data points and store (append) the results in a jsonl file: | |
| python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run | |
| ## run with profiling: | |
| python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile | |
| # Usage (correctness test): | |
| python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct | |
| ## Reference output (of the correctness test above, can be gpu dependent): | |
| input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]] | |
| prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], | |
| [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], | |
| [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]], | |
| device='cuda:0') | |
| prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141], | |
| [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781], | |
| [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]], | |
| device='cuda:0') | |
| ========== Prompt 0 ========== | |
| <s> The capital of France is Paris. | |
| The capital of the United States is Washington, D.C. | |
| ========== Prompt 1 ========== | |
| <s> The capital of the United Kindom is London. | |
| The capital of the United Kingdom is London. | |
| The capital of the | |
| ========== Prompt 2 ========== | |
| <s> Today is a sunny day and I like to go for a walk in the park. | |
| I'm going to the park | |
| """ | |
| import argparse | |
| import copy | |
| import dataclasses | |
| import itertools | |
| import json | |
| import logging | |
| import multiprocessing | |
| import os | |
| import time | |
| from types import SimpleNamespace | |
| from typing import Tuple | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| from sglang.srt.configs.model_config import ModelConfig | |
| from sglang.srt.distributed.parallel_state import destroy_distributed_environment | |
| from sglang.srt.entrypoints.engine import _set_envs_and_config | |
| from sglang.srt.layers.moe import initialize_moe_config | |
| from sglang.srt.managers.schedule_batch import Req, ScheduleBatch | |
| from sglang.srt.managers.scheduler import Scheduler | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | |
| from sglang.srt.model_executor.model_runner import ModelRunner | |
| from sglang.srt.sampling.sampling_params import SamplingParams | |
| from sglang.srt.server_args import PortArgs, ServerArgs | |
| from sglang.srt.speculative.spec_info import SpeculativeAlgorithm | |
| from sglang.srt.utils import ( | |
| configure_logger, | |
| get_bool_env_var, | |
| is_cuda_alike, | |
| is_xpu, | |
| kill_process_tree, | |
| maybe_reindex_device_id, | |
| require_mlp_sync, | |
| require_mlp_tp_gather, | |
| set_gpu_proc_affinity, | |
| suppress_other_loggers, | |
| ) | |
| from sglang.srt.utils.hf_transformers_utils import get_tokenizer | |
# Activities recorded by the torch profiler: CPU is always traced; the CUDA
# and XPU activities are added only when the matching backend check passes.
profile_activities = [torch.profiler.ProfilerActivity.CPU]
_cuda_alike_available = is_cuda_alike()
_xpu_available = is_xpu()
if _cuda_alike_available:
    profile_activities.append(torch.profiler.ProfilerActivity.CUDA)
if _xpu_available:
    profile_activities.append(torch.profiler.ProfilerActivity.XPU)
@dataclasses.dataclass
class BenchArgs:
    """Benchmark-only command-line options (everything that is not a server argument).

    The tuple-valued fields (`batch_size`, `input_len`, `output_len`) hold one
    or more values as given on the command line with ``nargs="+"``.
    """

    run_name: str = "default"
    batch_size: Tuple[int, ...] = (1,)
    input_len: Tuple[int, ...] = (1024,)
    output_len: Tuple[int, ...] = (16,)
    prompt_filename: str = ""
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every benchmark option on `parser`, using the field defaults."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        parser.add_argument(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        parser.add_argument(
            "--profile", action="store_true", help="Use Torch Profiler."
        )
        parser.add_argument(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a `BenchArgs` from parsed arguments.

        Each value is cast with the type of the field's default, so the lists
        produced by ``nargs="+"`` become tuples.
        """
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
def load_model(server_args, port_args, gpu_id, tp_rank):
    """Create the model runner and tokenizer for one tensor-parallel rank.

    Args:
        server_args: Server configuration; supplies the parallel sizes, memory
            fraction and tokenizer settings read below.
        port_args: Supplies `nccl_port` for the model runner.
        gpu_id: Device index handed to the model runner.
        tp_rank: Tensor-parallel rank of the calling process.

    Returns:
        A ``(model_runner, tokenizer)`` pair.
    """
    suppress_other_loggers()

    # Only rank 0 prints; the other ranks get a no-op.
    if tp_rank == 0:
        rank_print = print
    else:
        rank_print = lambda *args, **kwargs: None  # noqa: E731

    ranks_per_ep_group = server_args.tp_size // server_args.ep_size
    moe_ep_rank = tp_rank // ranks_per_ep_group

    model_config = ModelConfig.from_server_args(server_args)
    runner_kwargs = dict(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        pp_rank=0,
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    model_runner = ModelRunner(**runner_kwargs)
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")

    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )

    # With more than one rank, wait until every rank has reached this point.
    multi_rank = server_args.tp_size > 1
    if multi_rank:
        dist.barrier()

    return model_runner, tokenizer
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    """Tokenize the prompts and build requests holding only the first `cut_len` tokens.

    Args:
        bench_args: Benchmark arguments; `cut_len` and `output_len` are read.
        tokenizer: Object whose `encode(text)` returns a list of token ids.
        custom_prompts: Prompts to use; when empty, three built-in prompts are
            used instead.

    Returns:
        ``(input_ids, reqs)`` where `input_ids` holds the full token ids of
        every prompt and `reqs` the matching `Req` objects whose ids are
        truncated to `bench_args.cut_len`.
    """
    prompts = (
        custom_prompts
        if custom_prompts
        else [
            "The capital of France is",
            "The capital of the United Kindom is",
            "Today is a sunny day and I like",
        ]
    )
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        # Use the requested output length (same element the decode loop of the
        # correctness test uses) instead of the class-level default tuple.
        max_new_tokens=bench_args.output_len[0],
    )

    reqs = []
    for i, (prompt, token_ids) in enumerate(zip(prompts, input_ids)):
        # Every prompt must be longer than the cut so a second extend has work to do.
        assert len(token_ids) > bench_args.cut_len

        tmp_input_ids = token_ids[: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompt,
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return input_ids, reqs
def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Append the tokens after `cut_len` to every request for the second extend.

    Each request's `fill_ids` is extended in place with the remaining tokens of
    its prompt, and its `prefix_indices` is set to the first `cut_len` entries
    of its row in the model runner's `req_to_token` table.

    Returns:
        The same `reqs` list, updated in place.
    """
    cut = bench_args.cut_len
    token_table = model_runner.req_to_token_pool.req_to_token
    for row, (req, full_ids) in enumerate(zip(reqs, input_ids)):
        # In-place extension on purpose: `fill_ids` keeps referring to the same list.
        req.fill_ids += full_ids[cut:]
        req.prefix_indices = token_table[row, :cut]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    """Build one request per input row for a latency run.

    Args:
        batch_size: Number of random rows to generate when no custom inputs are given.
        input_len: Length of each random row.
        custom_inputs: Optional token-id sequences; when non-empty they are used
            as-is and `batch_size` / `input_len` are ignored.

    Returns:
        A list of `Req` objects, one per input row.
    """
    # Fall back to random token ids in [0, 10000) when nothing custom was supplied.
    input_ids = custom_inputs or np.random.randint(
        0, 10000, (batch_size, input_len), dtype=np.int32
    )

    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for rid, token_ids in enumerate(input_ids):
        req = Req(
            rid=rid,
            origin_input_text="",
            origin_input_ids=list(token_ids),
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)
    return reqs
@torch.no_grad()
def extend(reqs, model_runner):
    """Run one extend (prefill) forward pass over `reqs` and sample the next tokens.

    Gradient tracking is disabled: this is inference only, so recording
    autograd state would just cost memory and time.

    Args:
        reqs: Requests to schedule as a single batch.
        model_runner: Runner providing the pools, model config and
            `forward` / `sample`.

    Returns:
        ``(next_token_ids, next_token_logits, batch)`` where `batch` is the
        `ScheduleBatch` that later decode steps reuse.
    """
    # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
    dummy_tree_cache = SimpleNamespace(
        page_size=model_runner.server_args.page_size,
        device=model_runner.device,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
    )

    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=dummy_tree_cache,
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)

    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch
@torch.no_grad()
def decode(input_token_ids, batch, model_runner):
    """Run one decode step on an existing batch and sample the next tokens.

    Gradient tracking is disabled: this is inference only, so recording
    autograd state would just cost memory and time.

    Args:
        input_token_ids: Token ids produced by the previous step; stored on
            the batch as `output_ids` before the step is prepared.
        batch: The `ScheduleBatch` returned by `extend`.
        model_runner: Runner providing `forward` / `sample`.

    Returns:
        ``(next_token_ids, next_token_logits)``.
    """
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)

    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Call the scheduler's MLP-sync batch preparation when the server args require it.

    Does nothing when `require_mlp_sync(server_args)` is false.
    """
    args = model_runner.server_args
    if not require_mlp_sync(args):
        return

    Scheduler.prepare_mlp_sync_batch_raw(
        batch,
        dp_size=args.dp_size,
        attn_tp_size=1,
        tp_group=model_runner.tp_group,
        get_idle_batch=None,
        disable_cuda_graph=args.disable_cuda_graph,
        spec_algorithm=SpeculativeAlgorithm.NONE,
        speculative_num_draft_tokens=None,
        require_mlp_tp_gather=require_mlp_tp_gather(args),
        disable_overlap_schedule=args.disable_overlap_schedule,
        offload_tags=set(),
    )
| def _read_prompts_from_file(prompt_file, rank_print): | |
| """Read custom prompts from the file specified by `--prompt-filename`.""" | |
| if not prompt_file: | |
| return [] | |
| if not os.path.exists(prompt_file): | |
| rank_print( | |
| f"Custom prompt file {prompt_file} not found. Using default inputs..." | |
| ) | |
| return [] | |
| with open(prompt_file, "r") as pf: | |
| return pf.readlines() | |
def _save_profile_trace_results(profiler, filename):
    """Export the profiler's chrome trace to `filename` and print a summary table.

    The parent directory of `filename` is created if it does not exist. The
    printed table is grouped by input shape and sorted by self CPU time.
    """
    trace_path = os.path.abspath(filename)
    os.makedirs(os.path.dirname(trace_path), exist_ok=True)
    profiler.export_chrome_trace(filename)

    averages = profiler.key_averages(group_by_input_shape=True)
    print(averages.table(sort_by="self_cpu_time_total"))
def correctness_test(
    server_args,
    port_args,
    bench_args,
    gpu_id,
    tp_rank,
):
    """Run the correctness test: split prefill, greedy decode, print the decoded text.

    The prompts are first prefetched up to `bench_args.cut_len` tokens (when
    that is positive), then extended with the remaining tokens, and finally
    decoded for `bench_args.output_len[0] - 1` further steps. Logits and the
    decoded outputs are printed on rank 0 only.
    """
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    if tp_rank == 0:
        rank_print = print
    else:
        rank_print = lambda *args, **kwargs: None  # noqa: E731

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)

    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")

    # Prefill on the truncated prompts (skipped when nothing was cut off)
    if bench_args.cut_len > 0:
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )

    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode: start every output with the full prompt plus the first sampled token
    output_ids = [
        prompt_ids + [next_token_ids[idx]] for idx, prompt_ids in enumerate(input_ids)
    ]
    remaining_steps = bench_args.output_len[0] - 1
    for _step in range(remaining_steps):
        next_token_ids, _logits = decode(next_token_ids, batch, model_runner)
        for sequence, token in zip(output_ids, next_token_ids.tolist()):
            sequence.append(token)

    # Print output texts
    for idx, sequence in enumerate(output_ids):
        rank_print(f"========== Prompt {idx} ==========")
        rank_print(tokenizer.decode(sequence), "\n")
| def synchronize(device): | |
| torch.get_device_module(device).synchronize() | |
def _start_profiler(profile_record_shapes):
    """Create and start a torch profiler over the module-level `profile_activities`."""
    profiler = torch.profiler.profile(
        activities=profile_activities,
        with_stack=True,
        record_shapes=profile_record_shapes,
    )
    profiler.start()
    return profiler


def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_filename_prefix,
):
    """Time one prefill plus `output_len - 1` decode steps for a single configuration.

    Args:
        run_name: Label stored in the result.
        model_runner: Runner whose pools are cleared and which executes the batch.
        rank_print: Print function used for progress output.
        reqs: Requests forming the batch.
        batch_size, input_len, output_len: The configuration being measured;
            used for the size check and for throughput computation.
        device: Device identifier passed to `synchronize`.
        log_decode_step: When positive, every `log_decode_step`-th decode step
            is logged in addition to the first five.
        profile: When truthy, the prefill and one decode step are profiled and
            their traces are saved.
        profile_record_shapes: Passed to the profiler as `record_shapes`.
        profile_filename_prefix: Prefix of the trace file names.

    Returns:
        A dict of measurements, or ``None`` when the configuration is skipped
        because `batch_size` exceeds what `max_total_num_tokens` allows.
    """
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return None

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    profiler = None
    if profile:
        profiler = _start_profiler(profile_record_shapes)

    # Prefill
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    if profile:
        profiler.stop()
        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, trace_filename)
        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")

    # Decode
    # The step to profile must be an integer inside range(output_len - 1).
    # `output_len / 2` is a non-integer for odd lengths and out of range for
    # output_len == 2, which silently skipped the decode trace. For even
    # lengths >= 4 the chosen step is unchanged.
    decode_profile_step = min(output_len // 2, output_len - 2)

    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        profile_this_step = bool(profile) and i == decode_profile_step
        if profile_this_step:
            profiler = _start_profiler(profile_record_shapes)

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

        if profile_this_step:
            profiler.stop()
            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, trace_filename)
            rank_print(
                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
            )

    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results
def latency_test(
    server_args,
    port_args,
    bench_args,
    gpu_id,
    tp_rank,
):
    """Run the latency benchmark: one warmup, then a sweep over all configurations.

    The sweep covers every combination of `bench_args.batch_size`,
    `bench_args.input_len` and `bench_args.output_len`. Rank 0 appends the
    collected results to `bench_args.result_filename` as JSON lines.
    """
    initialize_moe_config(server_args)

    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(
            server_args.pp_size, server_args.tp_size, server_args.nnodes, tp_rank
        )

    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    is_rank_zero = tp_rank == 0
    if is_rank_zero:
        rank_print = print
    else:
        rank_print = lambda *args, **kwargs: None  # noqa: E731

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, gpu_id, tp_rank)

    # Prepare inputs for warm up (first batch size / input length of the sweep)
    warmup_bs = bench_args.batch_size[0]
    warmup_il = bench_args.input_len[0]
    reqs = prepare_synthetic_inputs_for_latency_test(warmup_bs, warmup_il)

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        warmup_bs,
        warmup_il,
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
    )

    rank_print("Benchmark ...")

    prompt_lines = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    custom_inputs = [tokenizer.encode(line.strip()) for line in prompt_lines]
    custom_input_len = len(custom_inputs)

    # Profiling is requested on rank 0 only; other ranks pass None.
    profile_flag = bench_args.profile if is_rank_zero else None
    record_shapes_flag = bench_args.profile_record_shapes if is_rank_zero else None

    # Run the sweep
    result_list = []
    sweep = itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    )
    for bs, il, ol in sweep:
        bs_aligned_inputs = []
        if custom_inputs:
            if custom_input_len > bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
                    f"Using the first {bs} prompts."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
            elif custom_input_len < bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
                    f"Pad to the desired batch_size with the last prompt."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs)
                padding = [bs_aligned_inputs[-1]] * (bs - custom_input_len)
                bs_aligned_inputs.extend(padding)
            else:
                bs_aligned_inputs = custom_inputs

        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            profile_flag,
            record_shapes_flag,
            bench_args.profile_filename_prefix,
        )
        # Skipped configurations return None and are left out of the results.
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if is_rank_zero and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.writelines(json.dumps(result) + "\n" for result in result_list)

    if server_args.tp_size > 1:
        destroy_distributed_environment()
def main(server_args, bench_args):
    """Pick the test to run and execute it on one process per tensor-parallel rank.

    With `tp_size == 1` the test runs in the current process; otherwise one
    child process is started per rank and all of them are joined.

    Raises:
        ValueError: If `server_args.model_path` is empty.
    """
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)

    _set_envs_and_config(server_args)

    if not server_args.model_path:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )
    work_func = correctness_test if bench_args.correctness_test else latency_test

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0, 0)
        return

    workers = []
    for tp_rank in range(server_args.tp_size):
        with maybe_reindex_device_id(tp_rank) as gpu_id:
            worker_args = (server_args, port_args, bench_args, gpu_id, tp_rank)
            proc = multiprocessing.Process(target=work_func, args=worker_args)
            proc.start()
            workers.append(proc)

    for proc in workers:
        proc.join()
        # The process has already been joined at this point.
        proc.terminate()
if __name__ == "__main__":
    # One parser carries both the server options and the benchmark options.
    parser = argparse.ArgumentParser()
    for args_cls in (ServerArgs, BenchArgs):
        args_cls.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    log_level = getattr(logging, server_args.log_level.upper())
    logging.basicConfig(level=log_level, format="%(message)s")

    try:
        main(server_args, bench_args)
    finally:
        # With several ranks, make sure no child process outlives the run.
        if server_args.tp_size != 1:
            kill_process_tree(os.getpid(), include_parent=False)
Xet Storage Details
- Size:
- 23.8 kB
- Xet hash:
- de5d799c957a9d2b89e699465edba7eefd11bfb0889ed9c344139bde8f770940
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.