Spaces:
Configuration error
Configuration error
| #!/usr/bin/env python3 | |
| import asyncio | |
| from concurrent import futures | |
| import argparse | |
| import signal | |
| import sys | |
| import os | |
| from typing import List | |
| from PIL import Image | |
| import backend_pb2 | |
| import backend_pb2_grpc | |
| import grpc | |
| from vllm.engine.arg_utils import AsyncEngineArgs | |
| from vllm.engine.async_llm_engine import AsyncLLMEngine | |
| from vllm.sampling_params import SamplingParams | |
| from vllm.utils import random_uuid | |
| from vllm.transformers_utils.tokenizer import get_tokenizer | |
| from vllm.multimodal.utils import fetch_image | |
| from vllm.assets.video import VideoAsset | |
| import base64 | |
| import io | |
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Worker count for the gRPC thread pool: read from PYTHON_GRPC_MAX_WORKERS
# when that environment variable is set, otherwise 1.
MAX_WORKERS = int(os.getenv('PYTHON_GRPC_MAX_WORKERS', '1'))
| # Implement the BackendServicer class with the service methods | |
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.

    ``LoadModel`` creates the vLLM engine (``self.llm``) and its tokenizer
    (``self.tokenizer``); ``Predict`` and ``PredictStream`` both delegate to
    ``_predict``.
    """

    def generate(self, prompt, max_new_tokens):
        """
        Generates text based on the given prompt and maximum number of new tokens.

        NOTE(review): this method drives ``self.generator``, which no method of
        this class assigns. It looks like a leftover from a different backend
        and raises ``AttributeError`` unless the attribute is set externally.

        Args:
            prompt (str): The prompt to generate text from.
            max_new_tokens (int): The maximum number of new tokens to generate.

        Returns:
            str: The generated text.
        """
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        decoded_text = ''
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            # Remember whether the very first generated piece starts with the
            # '▁' marker so a leading space can be re-added after decoding.
            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True

            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
            if has_leading_space:
                decoded_text = ' ' + decoded_text

            if token.item() == self.generator.tokenizer.eos_token_id:
                break
        return decoded_text

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply (the bytes ``OK``).
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a language model into a vLLM ``AsyncLLMEngine``.

        Optional request fields (quantization, load format, GPU memory
        utilization, ...) override the engine defaults only when they carry a
        non-zero / non-empty value.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result; ``success`` is False and
            ``message`` describes the error if the engine or tokenizer could
            not be created.
        """
        engine_args = AsyncEngineArgs(
            model=request.Model,
        )

        if request.Quantization != "":
            engine_args.quantization = request.Quantization
        if request.LoadFormat != "":
            engine_args.load_format = request.LoadFormat
        if request.GPUMemoryUtilization != 0:
            engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
        if request.TrustRemoteCode:
            engine_args.trust_remote_code = request.TrustRemoteCode
        if request.EnforceEager:
            engine_args.enforce_eager = request.EnforceEager
        if request.TensorParallelSize:
            engine_args.tensor_parallel_size = request.TensorParallelSize
        if request.SwapSpace != 0:
            engine_args.swap_space = request.SwapSpace
        if request.MaxModelLen != 0:
            engine_args.max_model_len = request.MaxModelLen

        try:
            self.llm = AsyncLLMEngine.from_engine_args(engine_args)
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        try:
            engine_model_config = await self.llm.get_model_config()
            self.tokenizer = get_tokenizer(
                engine_model_config.tokenizer,
                tokenizer_mode=engine_model_config.tokenizer_mode,
                trust_remote_code=engine_model_config.trust_remote_code,
                truncation_side="left",
            )
        except Exception as err:
            # Log tokenizer failures the same way engine failures are logged.
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        print("Model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    async def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        gen = self._predict(request, context, streaming=False)
        try:
            # In non-streaming mode _predict yields exactly one final reply.
            return await gen.__anext__()
        finally:
            # Do not leave the generator suspended at its last yield.
            await gen.aclose()

    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        NOTE(review): this method reads ``self.model``, which no method of this
        class assigns (``LoadModel`` sets ``self.llm`` and ``self.tokenizer``
        only); confirm where it is expected to come from.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An EmbeddingResult object that contains the calculated embeddings.
        """
        print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
        outputs = self.model.encode(request.Embeddings)
        # Check if we have one result at least
        if len(outputs) == 0:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("No embeddings were calculated.")
            return backend_pb2.EmbeddingResult()
        return backend_pb2.EmbeddingResult(embeddings=outputs[0].outputs.embedding)

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: One reply per engine iteration, each carrying
            only the newly generated text.
        """
        iterations = self._predict(request, context, streaming=True)
        try:
            async for iteration in iterations:
                yield iteration
        finally:
            await iterations.aclose()

    def _build_sampling_params(self, request):
        """
        Translate the request's sampling fields into a ``SamplingParams``.

        Fields left at their zero / empty value keep the defaults
        (``top_p=0.9``, ``max_tokens=200``, library defaults otherwise). The
        object is constructed once from the final values, rather than mutated
        afterwards, so the library's constructor-time validation and derived
        settings see what the caller actually asked for.

        Args:
            request: The predict request.

        Returns:
            SamplingParams: The sampling parameters for this request.
        """
        options = {"top_p": 0.9, "max_tokens": 200}
        if request.TopP != 0:
            options["top_p"] = request.TopP
        if request.Tokens > 0:
            options["max_tokens"] = request.Tokens
        if request.Temperature != 0:
            options["temperature"] = request.Temperature
        if request.TopK != 0:
            options["top_k"] = request.TopK
        if request.PresencePenalty != 0:
            options["presence_penalty"] = request.PresencePenalty
        if request.FrequencyPenalty != 0:
            options["frequency_penalty"] = request.FrequencyPenalty
        if request.StopPrompts:
            # Hand over a plain list instead of the request's own container.
            options["stop"] = list(request.StopPrompts)
        if request.IgnoreEOS:
            options["ignore_eos"] = request.IgnoreEOS
        if request.Seed != 0:
            options["seed"] = request.Seed
        return SamplingParams(**options)

    async def _predict(self, request, context, streaming=False):
        """
        Shared implementation behind ``Predict`` and ``PredictStream``.

        Args:
            request: The predict request.
            context: The gRPC context.
            streaming (bool): When True, yield one reply per engine iteration
                containing only the new text; when False, yield a single reply
                with the complete text once generation has finished.

        Yields:
            backend_pb2.Reply: Partial replies (streaming) or one final reply.
        """
        # Build sampling parameters
        sampling_params = self._build_sampling_params(request)

        # Decode the attached images and videos
        prompt = request.Prompt

        image_paths = request.Images
        image_data = [self.load_image(img_path) for img_path in image_paths]

        videos_path = request.Videos
        video_data = [self.load_video(video_path) for video_path in videos_path]

        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)

        # Generate text using the LLM engine
        request_id = random_uuid()
        print(f"Generating text with request_id: {request_id}", file=sys.stderr)
        multi_modal_data = {}
        if image_data:
            multi_modal_data["image"] = image_data
        if video_data:
            multi_modal_data["video"] = video_data
        outputs = self.llm.generate(
            {
                "prompt": prompt,
                "multi_modal_data": multi_modal_data if multi_modal_data else None,
            },
            sampling_params=sampling_params,
            request_id=request_id,
        )

        # Stream the results
        generated_text = ""
        try:
            async for request_output in outputs:
                iteration_text = request_output.outputs[0].text

                if streaming:
                    # Remove text already sent as vllm concatenates the text from previous yields
                    delta_iteration_text = iteration_text.removeprefix(generated_text)
                    # Send the partial result
                    yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8'))

                # Keep track of text generated
                generated_text = iteration_text
        finally:
            await outputs.aclose()

        # If streaming, we already sent everything
        if streaming:
            return

        # Remove image files, if any entry really is a file on disk. load_image
        # treats the entries as base64 payloads, so without this guard every
        # request would fail the removal and dump the whole payload to stderr.
        for img_path in image_paths:
            if not os.path.isfile(img_path):
                continue
            try:
                os.remove(img_path)
            except Exception as e:
                print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)

        # Sending the final generated text
        yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))

    def load_image(self, image_path: str):
        """
        Load an image from base64 encoded data.

        Despite the parameter name, the value is always base64-decoded; a
        plain file path is not opened.

        Args:
            image_path (str): The base64 encoded image data.

        Returns:
            Image: The loaded image, or None if decoding or opening failed
            (the error is printed to stderr).
        """
        try:
            image_data = base64.b64decode(image_path)
            image = Image.open(io.BytesIO(image_data))
            return image
        except Exception as e:
            print(f"Error loading image {image_path}: {e}", file=sys.stderr)
            return None

    def load_video(self, video_path: str):
        """
        Load a video from base64 encoded data.

        The decoded bytes are written to a uniquely named temporary file whose
        path is handed to ``VideoAsset``; the file is removed again whether or
        not loading succeeds.

        Args:
            video_path (str): The base64 encoded video data.

        Returns:
            The ``np_ndarrays`` value of the ``VideoAsset``, or None if loading
            failed (the error is printed to stderr).
        """
        # Imported here because only this method needs it.
        import tempfile

        tmp_path = None
        try:
            payload = base64.b64decode(video_path)
            # delete=False: the file must still exist after the handle is
            # closed so that VideoAsset can read it by path.
            with tempfile.NamedTemporaryFile(prefix="vl-", suffix=".data", delete=False) as f:
                tmp_path = f.name
                f.write(payload)
            return VideoAsset(name=tmp_path).np_ndarrays
        except Exception as e:
            print(f"Error loading video {video_path}: {e}", file=sys.stderr)
            return None
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError as e:
                    print(f"Error removing video file: {tmp_path}, {e}", file=sys.stderr)
async def serve(address):
    """
    Run the asyncio gRPC server on ``address`` until it is terminated.

    SIGINT and SIGTERM each schedule a server stop with a 5 second grace
    period.

    Args:
        address (str): The address to bind the server to.
    """
    # Build the asyncio gRPC server on top of a bounded thread pool
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    server = grpc.aio.server(migration_thread_pool=worker_pool)

    # Register the servicer, then bind the listening address
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    def _request_shutdown():
        # Signal handlers are plain callbacks, so the stop coroutine is
        # scheduled on the loop instead of being awaited here.
        asyncio.ensure_future(server.stop(5))

    # Shut down gracefully on SIGINT or SIGTERM
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _request_shutdown)

    # Start serving and block until the server has terminated
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    await server.wait_for_termination()
if __name__ == "__main__":
    # Command-line entry point: parse the bind address and run the server.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    args = cli.parse_args()

    asyncio.run(serve(args.addr))