| import atexit | |
| import json | |
| import multiprocessing | |
| import warnings | |
| from typing import Dict, List, Optional, Union | |
| import aiohttp | |
| import requests | |
| from sglang.global_config import global_config | |
| from sglang.lang.backend.base_backend import BaseBackend | |
| from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path | |
| from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod | |
| from sglang.lang.interpreter import StreamExecutor | |
| from sglang.lang.ir import ( | |
| REGEX_BOOL, | |
| REGEX_FLOAT, | |
| REGEX_INT, | |
| REGEX_STR, | |
| SglSamplingParams, | |
| ) | |
| from sglang.utils import http_request | |
class RuntimeEndpoint(BaseBackend):
    """Frontend backend that drives a running SGLang HTTP server.

    Every backend operation is turned into an HTTP request against
    ``base_url``; any reply whose status code is not 200 is raised as a
    ``RuntimeError`` carrying the server's error payload.
    """

    # Optional logprob settings copied from SglSamplingParams to the top level
    # of a /generate request body whenever they are not None.
    _LOGPROB_REQUEST_FIELDS = (
        "return_logprob",
        "logprob_start_len",
        "top_logprobs_num",
        "return_text_in_logprobs",
    )

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        """Connect to the server at *base_url* and fetch its model info.

        Args:
            base_url: Root URL of the server; endpoint paths are appended to it.
            api_key: Optional API key forwarded with every request.
            verify: Forwarded unchanged to ``http_request`` as ``verify``.
            chat_template_name: Explicit chat template name. When omitted, the
                template is looked up from the server's reported model path.

        Raises:
            RuntimeError: If ``/get_model_info`` does not answer with HTTP 200.
        """
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        self.model_info = self._call_server("/get_model_info").json()

        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            self.chat_template = get_chat_template_by_model_path(
                self.model_info["model_path"]
            )

    def get_model_name(self):
        """Return the model path reported by the server."""
        return self.model_info["model_path"]

    def flush_cache(self):
        """Ask the server to flush its cache (POST ``/flush_cache``)."""
        self._call_server("/flush_cache", method="POST")

    def get_server_info(self):
        """Return the decoded JSON body of ``/get_server_info``."""
        return self._call_server("/get_server_info").json()

    def get_chat_template(self):
        """Return the chat template selected in ``__init__``."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Send *prefix_str* to ``/generate`` with ``max_new_tokens=0``.

        No tokens are generated; presumably the point is to let the server
        cache the prefix for later requests.
        """
        self._call_server(
            "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
        )

    def start_profile(self):
        """Call the server's ``/start_profile`` endpoint."""
        self._call_server("/start_profile")

    def stop_profile(self):
        """Call the server's ``/stop_profile`` endpoint."""
        self._call_server("/stop_profile")

    def commit_lazy_operations(self, s: StreamExecutor):
        """Submit the current text (and image, if any) of *s* without generating."""
        self._prefill_state(s)

    def fill_image(self, s: StreamExecutor):
        """Submit the current text (and image, if any) of *s* without generating."""
        self._prefill_state(s)

    def _prefill_state(self, s: StreamExecutor):
        # Shared body of commit_lazy_operations/fill_image: a zero-token request.
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        self._call_server("/generate", json=data)

    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        """Translate ``sampling_params.dtype`` into a constraining regex, in place.

        For ``int``/``float`` a space and a newline are also added as stop
        strings. A pre-existing ``regex`` is overridden (with a warning).

        Raises:
            RuntimeError: If ``dtype`` is set to an unsupported value.
        """
        if sampling_params.dtype is None:
            return

        if sampling_params.stop == ():
            sampling_params.stop = []

        dtype = sampling_params.dtype
        if dtype in ["int", int]:
            dtype_regex = REGEX_INT
            self._append_whitespace_stops(sampling_params)
        elif dtype in ["float", float]:
            dtype_regex = REGEX_FLOAT
            self._append_whitespace_stops(sampling_params)
        elif dtype in ["str", str]:
            dtype_regex = REGEX_STR
        elif dtype in ["bool", bool]:
            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        if dtype_regex is not None and sampling_params.regex is not None:
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        sampling_params.regex = dtype_regex

    @staticmethod
    def _append_whitespace_stops(sampling_params: SglSamplingParams):
        """Add a space and a newline to ``sampling_params.stop``.

        ``stop`` may be None, a single string, or a list/tuple of strings. A new
        list is always assigned instead of extending in place: tuples and
        strings have no ``extend``, and a caller-owned list must not be mutated.
        """
        stop = sampling_params.stop
        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop]
        sampling_params.stop = [*stop, " ", "\n"]

    def _build_generate_request(
        self, s: StreamExecutor, sampling_params: SglSamplingParams
    ):
        """Build a ``/generate`` request body for *s* (images not yet attached).

        Applies the dtype-to-regex translation to *sampling_params* first.
        """
        self._handle_dtype_to_regex(sampling_params)

        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }
        for field in self._LOGPROB_REQUEST_FIELDS:
            value = getattr(sampling_params, field, None)
            if value is not None:
                data[field] = value
        return data

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation.

        Returns:
            ``(text, meta_info)`` taken from the server's JSON reply.
        """
        data = self._build_generate_request(s, sampling_params)
        self._add_images(s, data)
        obj = self._call_server("/generate", json=data).json()
        return obj["text"], obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a streaming generation, yielding ``(new_text, meta_info)`` pairs.

        Each server event carries the full text generated so far; only the part
        not yet yielded is returned. Stops at the ``data: [DONE]`` event.
        """
        data = self._build_generate_request(s, sampling_params)
        data["stream"] = True
        self._add_images(s, data)

        res = self._call_server("/generate", json=data, stream=True)

        pos = 0
        for raw_line in res.iter_lines(decode_unicode=False):
            line = raw_line.decode("utf-8")
            if not line.startswith("data:"):
                continue
            if line == "data: [DONE]":
                break
            event = json.loads(line[5:].strip("\n"))
            chunk_text = event["text"][pos:]
            pos += len(chunk_text)
            yield chunk_text, event["meta_info"]

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Pick one of *choices* by scoring each continuation of ``s.text_``.

        The shared prompt is sent first (``max_new_tokens=0``) to learn its
        token length; then every ``s.text_ + choice`` is scored in a single
        batched request with input logprobs enabled. The collected logprobs are
        handed to *choices_method*, which makes the final decision.

        Args:
            s: Stream state providing the prompt text (and optional image).
            choices: Candidate continuations.
            temperature: Must be (near) zero; selection is deterministic.
            choices_method: Strategy turning logprobs into a decision. When its
                ``requires_unconditional_logprobs`` is set, an extra request
                re-scores the scored token ids on their own.
        """
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        obj = self._generate_http_request(s, data)
        prompt_len = obj["meta_info"]["prompt_tokens"]
        # Start logprobs two tokens before the prompt end, for token healing.
        logprob_start_len = max(prompt_len - 2, 0)

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        obj = self._generate_http_request(s, data)

        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
        normalized_prompt_logprobs = [
            compute_normalized_prompt_logprobs(token_logprobs)
            for token_logprobs in input_token_logprobs
        ]

        # Remove extra token if no token healing occurred
        for i, token_logprobs in enumerate(input_token_logprobs):
            healed_token_str = token_logprobs[0][-1]
            if s.text_.endswith(healed_token_str):
                trimmed = token_logprobs[1:]
                # Re-average the remaining tokens with the same helper instead of
                # back-solving from the old mean, which fails (TypeError) when
                # the dropped token's logprob is None.
                normalized_prompt_logprobs[i] = compute_normalized_prompt_logprobs(
                    trimmed
                )
                input_token_logprobs[i] = trimmed

        # Compute unconditional logprobs if required
        if choices_method.requires_unconditional_logprobs:
            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
            data = {
                "input_ids": input_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            obj = self._generate_http_request(s, data)
            unconditional_token_logprobs = [
                r["meta_info"]["input_token_logprobs"] for r in obj
            ]
        else:
            unconditional_token_logprobs = None

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Call ``/concate_and_append_request`` with the given request ids."""
        self._call_server(
            "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
        )

    def _generate_http_request(self, s: StreamExecutor, data):
        """Attach images of *s* to *data*, POST it to /generate, return the JSON reply."""
        self._add_images(s, data)
        return self._call_server("/generate", json=data).json()

    def _add_images(self, s: StreamExecutor, data):
        """Copy the single image of *s* (if any) into ``data["image_data"]``."""
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]

    def _call_server(self, path: str, **kwargs):
        """Send a request to ``base_url + path`` and return the response.

        Extra keyword arguments (``json``, ``method``, ``stream``) are passed
        through to ``http_request``; the API key and ``verify`` setting of this
        endpoint are always added.

        Raises:
            RuntimeError: If the server does not answer with HTTP 200.
        """
        res = http_request(
            self.base_url + path,
            api_key=self.api_key,
            verify=self.verify,
            **kwargs,
        )
        self._assert_success(res)
        return res

    def _assert_success(self, res):
        """Raise ``RuntimeError`` with the error body unless *res* has status 200."""
        if res.status_code == 200:
            return
        try:
            content = res.json()
        except ValueError:
            # Both json.JSONDecodeError and requests' JSONDecodeError (which may
            # be simplejson-based) subclass ValueError; fall back to raw text.
            content = res.text
        raise RuntimeError(content)
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean logprob over the entries of *input_logprobs* that have one.

    Each entry is indexed with ``[0]`` to get its logprob. Entries whose logprob
    is ``None`` are skipped. A logprob of exactly ``0.0`` is a real value and is
    averaged in; a truthiness test would wrongly drop it as falsy.

    Raises:
        ZeroDivisionError: If no entry carries a logprob.
    """
    values = [x[0] for x in input_logprobs if x[0] is not None]
    return sum(values) / len(values)
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing without the frontend language.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs

        Launches the server in a spawned process, waits for it to report
        readiness over a pipe, then connects a ``RuntimeEndpoint`` to it.

        Raises:
            RuntimeError: If no free port is found, or if the server process
                does not report ``"ready"``.
        """
        # NOTE: We store pid instead of proc to fix some issues during __delete__
        # Set before anything can fail so shutdown()/__del__ are always safe.
        self.pid = None

        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
        # client code without installing SRT server and its dependency if they want.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports: probe upward from the requested port (capped at
        # 40000, or at the last valid port when the request is already above it).
        start_port = self.server_args.port
        end_port = 40000 if start_port < 40000 else 65536
        for port in range(start_port, end_port):
            if is_port_available(port):
                break
        else:
            raise RuntimeError(
                f"No available port found in range [{start_port}, {end_port})."
            )
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        ctx = multiprocessing.get_context("spawn")
        proc = ctx.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""
        finally:
            pipe_reader.close()

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree if one is running; safe to call repeatedly.

        Also safe on a partially constructed instance (no ``pid`` attribute).
        """
        pid = getattr(self, "pid", None)
        if pid is None:
            return
        from sglang.srt.utils import kill_process_tree

        kill_process_tree(pid)
        self.pid = None

    def start_profile(self):
        """Forward to the endpoint's ``start_profile``."""
        self.endpoint.start_profile()

    def stop_profile(self):
        """Forward to the endpoint's ``stop_profile``."""
        self.endpoint.stop_profile()

    def cache_prefix(self, prefix: str):
        """Forward *prefix* to the endpoint's ``cache_prefix``."""
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        """Load the tokenizer described by this runtime's server arguments."""
        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    @staticmethod
    async def _iter_sse_lines(content):
        """Yield the non-empty, stripped text lines of a streamed response body.

        *content* must offer ``iter_chunks()`` yielding ``(bytes, flag)`` pairs
        (aiohttp's ``StreamReader``). Network chunks respect neither line nor
        UTF-8 character boundaries, so bytes are buffered until a whole line is
        available and only then decoded. A trailing unterminated line is
        yielded at the end of the stream.
        """
        buffer = b""
        async for chunk, _ in content.iter_chunks():
            buffer += chunk
            *complete, buffer = buffer.split(b"\n")
            for raw_line in complete:
                line = raw_line.decode("utf-8").strip()
                if line:
                    yield line
        tail = buffer.decode("utf-8").strip()
        if tail:
            yield tail

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        """Stream a generation from the server.

        The prompt is sent as ``input_ids`` when the server skips tokenizer
        init, otherwise as ``text``. For each streamed event that has a
        ``"text"`` field, the not-yet-yielded suffix is yielded (if non-empty);
        any other event is yielded as its decoded JSON object.
        """
        prompt_key = "input_ids" if self.server_args.skip_tokenizer_init else "text"
        json_data = {
            prompt_key: prompt,
            "sampling_params": sampling_params,
            "stream": True,
        }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                lines = self._iter_sse_lines(response.content)
                try:
                    async for line in lines:
                        if not line.startswith("data:"):
                            continue
                        if line == "data: [DONE]":
                            break
                        data = json.loads(line[5:])
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data
                finally:
                    # Close the line reader deterministically on early exit.
                    await lines.aclose()

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        """POST a blocking generation request and return the reply re-serialized as a JSON string."""
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        # A list of LoRA paths must pair one-to-one with the prompts.
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """POST *prompt* to ``/encode`` and return the reply as a JSON string."""
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        """Fetch ``/get_server_info``; raise ``RuntimeError`` on a non-200 reply."""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()
Xet Storage Details
- Size:
- 17.5 kB
- Xet hash:
- 3a42bf1ffebc84f6276ad7b81bfddcaf39b1e5431c36a9c374e46deb6bbf72a7
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.