| | """ |
| | A model worker that executes the model based on LightLLM. |
| | |
| | See documentations at docs/lightllm_integration.md |
| | """ |
| |
|
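# Illustrative launch command (a sketch: the module path and model directory
# are placeholders; the flags are the ones defined by the parser below):
#
#   python -m src.serve.lightllm_worker \
#       --model-path /path/to/model \
#       --host 127.0.0.1 --port 8000 \
#       --tp 1 --max_total_token_num 6000
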
import argparse
import asyncio
import json
import os
import torch
import uvicorn

from transformers import AutoConfig

from typing import List

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse

from src.serve.base_model_worker import BaseModelWorker
from src.serve.model_worker import (
    logger,
    worker_id,
)

from lightllm.server.sampling_params import SamplingParams
from lightllm.server.multimodal_params import MultimodalParams
from lightllm.server.httpserver.manager import HttpServerManager
from lightllm.server.detokenization.manager import start_detokenization_process
from lightllm.server.router.manager import start_router_process
from lightllm.server.req_id_generator import ReqIDGenerator

from lightllm.utils.net_utils import alloc_can_use_network_port
from lightllm.utils.start_utils import start_submodule_processes
from fastchat.utils import get_context_length, is_partial_stop

app = FastAPI()
g_id_gen = ReqIDGenerator()


class LightLLMWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        conv_template: str,
        tokenizer,
        context_len,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
        )

        logger.info(
            f"Loading the model {self.model_names} on worker {worker_id}, worker type: LightLLM worker..."
        )
        self.tokenizer = tokenizer
        self.context_len = context_len

        self.is_first = True

        if not no_register:
            self.init_heart_beat()

    async def generate_stream(self, params):
        self.call_ct += 1

        prompt = params.pop("prompt")
        request_id = params.pop("request_id")
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", -1))  # -1 disables top-k filtering
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        max_new_tokens = params.get("max_new_tokens", 256)
        echo = params.get("echo", True)
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_token_ids", None) or []
        if self.tokenizer.eos_token_id is not None:
            stop_token_ids.append(self.tokenizer.eos_token_id)

        request = params.get("request", None)

        # Collect stop strings from both `stop` and `stop_token_ids`.
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        for tid in stop_token_ids:
            if tid is not None:
                s = self.tokenizer.decode(tid)
                if s != "":
                    stop.add(s)
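        # Note: with a Llama-style tokenizer, for example, eos id 2 decodes to
        # "</s>"; that string was added to `stop` above, so stopping is
        # matched on decoded text rather than on token ids.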

        if self.is_first:
            # Start the response-dispatch loop lazily, on the first request,
            # so it attaches to uvicorn's running event loop.
            loop = asyncio.get_event_loop()
            loop.create_task(httpserver_manager.handle_loop())
            self.is_first = False

        # Keep top_p strictly positive so SamplingParams.verify() accepts it;
        # with near-zero temperature decoding is effectively greedy, so top_p
        # is reset to 1.0.
        top_p = max(top_p, 1e-5)
        if temperature <= 1e-5:
            top_p = 1.0

        sampling_params = SamplingParams(
            do_sample=temperature > 0.0,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_new_tokens,
            stop_sequences=list(stop),
        )
        sampling_params.verify()

        results_generator = httpserver_manager.generate(
            prompt, sampling_params, request_id, MultimodalParams()
        )

        completion_tokens = 0
        text_outputs = ""
        cumulative_logprob = 0.0

        async for request_output, metadata, finish_status in results_generator:
            text_outputs += request_output
            completion_tokens += 1

            # Hold back text that ends with a prefix of a stop string so that
            # a partially generated stop sequence is never streamed out.
            partial_stop = any(is_partial_stop(text_outputs, i) for i in stop)
            if partial_stop:
                continue

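            # E.g., with an illustrative stop string "</s>", a cumulative text
            # ending in "</" is withheld by the check above; streaming resumes
            # once the pending text no longer matches a stop-string prefix.
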
            if type(finish_status) is bool:  # compatibility with older LightLLM versions
                finish_reason = "stop" if finish_status else None
            else:  # newer versions return a finish-status object
                finish_reason = finish_status.get_finish_reason()

            if request and await request.is_disconnected():
                await httpserver_manager.abort(request_id)
                finish_reason = "abort"

            logprob = metadata.get("logprob", None)
            if logprob is not None:
                cumulative_logprob += logprob

            prompt_tokens = metadata["prompt_tokens"]
            ret = {
                "text": prompt + text_outputs if echo else text_outputs,
                "error_code": 0,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
                "cumulative_logprob": cumulative_logprob,
            }

            # Every chunk is a null-terminated JSON blob carrying the full
            # cumulative text. On the final iteration, first emit the chunk
            # with finish_reason=None, then again with the real finish_reason.
            if finish_reason is not None:
                yield (
                    json.dumps({**ret, "finish_reason": None}, ensure_ascii=False)
                    + "\0"
                ).encode("utf-8")
            yield (
                json.dumps({**ret, "finish_reason": finish_reason}, ensure_ascii=False)
                + "\0"
            ).encode("utf-8")

            if finish_reason is not None:
                break

    async def generate(self, params):
        # Drain the stream; the final chunk carries the complete text,
        # usage counts, and finish_reason.
        async for x in self.generate_stream(params):
            pass
        return json.loads(x[:-1].decode())


def release_worker_semaphore():
    worker.semaphore.release()


def acquire_worker_semaphore():
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(request_id):
    async def abort_request() -> None:
        await httpserver_manager.abort(request_id)

    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_worker_semaphore)
    background_tasks.add_task(abort_request)
    return background_tasks

| | @app.post("/worker_generate_stream") |
| | async def api_generate_stream(request: Request): |
| | params = await request.json() |
| | await acquire_worker_semaphore() |
| | request_id = g_id_gen.generate_id() |
| | params["request_id"] = request_id |
| | params["request"] = request |
| | generator = worker.generate_stream(params) |
| | background_tasks = create_background_tasks(request_id) |
| | return StreamingResponse(generator, background=background_tasks) |
| |
|
| |
|
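# Illustrative client call (the field names are the ones read in
# generate_stream; host/port match the defaults below):
#
#   curl -N http://127.0.0.1:8000/worker_generate_stream \
#       -H "Content-Type: application/json" \
#       -d '{"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 32}'
#
# The response is a stream of null-byte-terminated JSON objects, each with the
# cumulative "text", the "usage" counts, and a "finish_reason" that stays null
# until the final chunk.
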
| | @app.post("/worker_generate") |
| | async def api_generate(request: Request): |
| | params = await request.json() |
| | await acquire_worker_semaphore() |
| | request_id = g_id_gen.generate_id() |
| | params["request_id"] = request_id |
| | params["request"] = request |
| | output = await worker.generate(params) |
| | release_worker_semaphore() |
| | await httpserver_manager.abort(request_id) |
| | return JSONResponse(output) |
| |
|
| |
|
| | @app.post("/worker_get_status") |
| | async def api_get_status(request: Request): |
| | return worker.get_status() |
| |
|
| |
|
| | @app.post("/count_token") |
| | async def api_count_token(request: Request): |
| | params = await request.json() |
| | return worker.count_token(params) |
| |
|
| |
|
| | @app.post("/worker_get_conv_template") |
| | async def api_get_conv(request: Request): |
| | return worker.get_conv_template() |
| |
|
| |
|
| | @app.post("/model_details") |
| | async def api_model_details(request: Request): |
| | return {"context_length": worker.context_len} |
| |
|
| |
|

if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)

    parser.add_argument(
        "--model-path",
        dest="model_dir",
        type=str,
        default=None,
        help="The model weight directory; config, weights, and tokenizer are loaded from it.",
    )
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional comma-separated display names.",
    )
    parser.add_argument("--limit-worker-concurrency", type=int, default=1024)
    parser.add_argument("--no-register", action="store_true")

    parser.add_argument(
        "--tokenizer_mode",
        type=str,
        default="slow",
        help="""Tokenizer load mode: "slow" or "auto". Slow mode loads quickly
        but runs slowly, which is fine for debugging and testing; for best
        performance, try auto mode.""",
    )
    parser.add_argument(
        "--load_way",
        type=str,
        default="HF",
        help="How model weights are loaded: HF (Hugging Face format, the default); llama also supports DS (DeepSpeed).",
    )
    parser.add_argument(
        "--max_total_token_num",
        type=int,
        default=6000,
        help="The total number of tokens the GPU and model can hold; equals max_batch * (input_len + output_len).",
    )
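    # E.g., with illustrative numbers: a batch of 10 requests, each using
    # 300 input + 300 output tokens, needs max_total_token_num >=
    # 10 * (300 + 300) = 6000, which is the default above.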
    parser.add_argument(
        "--batch_max_tokens",
        type=int,
        default=None,
        help="Max token count of a newly composed batch; caps the prefill batch size to prevent OOM.",
    )
    parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id")
    parser.add_argument(
        "--running_max_req_size",
        type=int,
        default=1000,
        help="The max number of requests running forward at the same time.",
    )
    parser.add_argument(
        "--tp", type=int, default=1, help="Model tensor-parallel size; the default is 1."
    )
    parser.add_argument(
        "--max_req_input_len",
        type=int,
        default=None,
        help="The max number of input tokens per request. If None, it is derived from the config.",
    )
    parser.add_argument(
        "--max_req_total_len",
        type=int,
        default=None,
        help="The max value of req_input_len + req_output_len. If None, it is derived from the config.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default=[],
        nargs="+",
        help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
        | triton_gqa_attention | triton_gqa_flashdecoding]
        [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight].
        triton_flashdecoding is for long context and currently supports llama, llama2, and qwen;
        triton_gqa_attention and triton_gqa_flashdecoding are fast kernels for models that use GQA;
        triton_int8kv stores the kv cache in int8 to increase token capacity, using triton kernels;
        ppl_int8kv does the same with ppl fast kernels;
        ppl_fp16 uses the ppl fast fp16 decode attention kernel;
        triton_int8weight, triton_int4weight, lmdeploy_int4weight, and ppl_int4weight store weights in int8 or int4;
        read the source code to confirm which modes each model supports.""",
    )
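    # Illustrative usage, combining one attention/kv-cache mode with one
    # weight-quantization mode (support varies by model, as the help notes):
    #   --mode triton_flashdecoding triton_int8weight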
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Whether to allow custom models defined on the Hub in their own modeling files.",
    )
    parser.add_argument(
        "--disable_log_stats",
        action="store_true",
        help="Disable logging of throughput stats.",
    )
    parser.add_argument(
        "--log_stats_interval",
        type=int,
        default=10,
        help="Log stats interval, in seconds.",
    )

    parser.add_argument(
        "--router_token_ratio",
        type=float,
        default=0.0,
        help="Token ratio used to control router dispatch.",
    )
    parser.add_argument(
        "--router_max_new_token_len",
        type=int,
        default=1024,
        help="The max new-token length per request assumed by the router.",
    )

    parser.add_argument(
        "--no_skipping_special_tokens",
        action="store_true",
        help="Do not skip special tokens when decoding.",
    )
    parser.add_argument(
        "--no_spaces_between_special_tokens",
        action="store_true",
        help="Do not add spaces between special tokens when decoding.",
    )

    parser.add_argument(
        "--splitfuse_mode", action="store_true", help="use splitfuse mode"
    )
    parser.add_argument(
        "--splitfuse_block_size", type=int, default=256, help="splitfuse block size"
    )
    parser.add_argument(
        "--prompt_cache_strs",
        type=str,
        default=[],
        nargs="+",
        help="Prompt strings to cache (splitfuse mode only).",
    )
    parser.add_argument(
        "--cache_capacity",
        type=int,
        default=200,
        help="Cache server capacity for multimodal resources.",
    )
    parser.add_argument(
        "--cache_reserved_ratio",
        type=float,
        default=0.5,
        help="Cache server reserved capacity ratio after a clear.",
    )
    parser.add_argument(
        "--return_all_prompt_logprobs",
        action="store_true",
        help="Return logprobs for all prompt tokens.",
    )
    parser.add_argument(
        "--long_truncation_mode",
        type=str,
        choices=[None, "head", "center"],
        default=None,
        help="""How to handle inputs longer than max_req_input_len:
        None   : raise an exception;
        head   : drop tokens from the head until the input fits;
        center : drop tokens from the center until the input fits.""",
    )

    args = parser.parse_args()

    # Prompt caching is only supported in splitfuse mode.
    if not args.splitfuse_mode:
        assert len(args.prompt_cache_strs) == 0

    model_config = AutoConfig.from_pretrained(
        args.model_dir, trust_remote_code=args.trust_remote_code
    )
    context_length = get_context_length(model_config)

    if args.max_req_input_len is None:
        args.max_req_input_len = context_length - 1
    if args.max_req_total_len is None:
        args.max_req_total_len = context_length

    assert args.max_req_input_len < args.max_req_total_len
    assert args.max_req_total_len <= args.max_total_token_num

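    # E.g., for a model with a 4096-token context window and both flags left
    # unset: max_req_input_len = 4095, max_req_total_len = 4096, and the
    # asserts then require max_total_token_num >= 4096 (the default 6000
    # passes).
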
    if not args.splitfuse_mode:
        # Normal (non-splitfuse) mode.
        if args.batch_max_tokens is None:
            batch_max_tokens = int(1 / 6 * args.max_total_token_num)
            batch_max_tokens = max(batch_max_tokens, args.max_req_total_len)
            args.batch_max_tokens = batch_max_tokens
        else:
            assert (
                args.batch_max_tokens >= args.max_req_total_len
            ), "batch_max_tokens must >= max_req_total_len"
    else:
        # Splitfuse mode: prefill work is split into fixed-size blocks, so
        # the lower bound is the block size rather than max_req_total_len.
        if args.batch_max_tokens is None:
            batch_max_tokens = int(1 / 6 * args.max_total_token_num)
            batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size)
            args.batch_max_tokens = batch_max_tokens

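    # E.g., with the defaults and a 4096-token model: int(1 / 6 * 6000) =
    # 1000, then max(1000, 4096) = 4096, so batch_max_tokens becomes 4096.
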
    # Allocate 6 fixed ports (router, detokenization, http server, visual,
    # cache, nccl) plus one model RPC port per tensor-parallel rank.
    can_use_ports = alloc_can_use_network_port(num=6 + args.tp)

    assert can_use_ports is not None, "Can not alloc enough free ports."
    (
        router_port,
        detokenization_port,
        httpserver_port,
        visual_port,
        cache_port,
        nccl_port,
    ) = can_use_ports[0:6]
    args.nccl_port = nccl_port
    model_rpc_ports = can_use_ports[6:]

    global httpserver_manager
    httpserver_manager = HttpServerManager(
        args,
        router_port=router_port,
        cache_port=cache_port,
        visual_port=visual_port,
        httpserver_port=httpserver_port,
        enable_multimodal=False,
    )

    start_submodule_processes(
        start_funcs=[start_router_process, start_detokenization_process],
        start_args=[
            (args, router_port, detokenization_port, model_rpc_ports),
            (args, detokenization_port, httpserver_port),
        ],
    )
    worker = LightLLMWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_dir,
        args.model_names,
        args.limit_worker_concurrency,
        args.no_register,
        args.conv_template,
        httpserver_manager.tokenizer,
        context_length,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")