#!/usr/bin/env python3
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
from typing import List
import time
import backend_pb2
import backend_pb2_grpc
import grpc
from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
import mlx.core as mx
import base64
import io
from mlx_cache import ThreadSafeLRUPromptCache

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# If PYTHON_GRPC_MAX_WORKERS is specified in the environment, use it; otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
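# Illustrative usage (assumption, not part of the original file): setting
#   PYTHON_GRPC_MAX_WORKERS=4
# before launching this backend sizes the ThreadPoolExecutor passed to grpc.aio.server
# below with 4 worker threads instead of 1.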

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """
    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a language model using MLX.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            print(f"Loading MLX model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the diffusers backend
            options = request.Options
            self.options = {}

            # The options are a list of strings in the form optname:optvalue
            # We store all the options in a dict for later use
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons
                # Convert numeric values to appropriate types; check int before float
                # so that values like "40" stay integers instead of becoming 40.0
                if is_int(value):
                    value = int(value)
                elif is_float(value):
                    value = float(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"
                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Build tokenizer config for MLX using options
            tokenizer_config = {}

            # Handle trust_remote_code from request or options
            if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                tokenizer_config["trust_remote_code"] = True

            # Handle EOS token from options
            if "eos_token" in self.options:
                tokenizer_config["eos_token"] = self.options["eos_token"]

            # Handle other tokenizer config options
            for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
                if key in self.options:
                    tokenizer_config[key] = self.options[key]

            # Load model and tokenizer using MLX
            if tokenizer_config:
                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
                self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
            else:
                self.model, self.tokenizer = load(request.Model)

            # Initialize thread-safe LRU prompt cache for efficient generation
            max_cache_entries = self.options.get("max_cache_entries", 10)
            self.max_kv_size = self.options.get("max_kv_size", None)
            self.model_key = request.Model
            self.lru_cache = ThreadSafeLRUPromptCache(
                max_size=max_cache_entries,
                can_trim_fn=can_trim_prompt_cache,
                trim_fn=trim_prompt_cache,
            )
        except Exception as err:
            print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")

        print("MLX model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX model loaded successfully", success=True)

    async def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters using MLX.
        Uses a thread-safe LRU prompt cache for efficient prefix reuse across requests.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        prompt_cache = None
        cache_key = None
        try:
            # Prepare the prompt and tokenize it for the cache key
            prompt_text = self._prepare_prompt(request)
            cache_key = self._get_tokens_from_prompt(prompt_text)

            # Fetch nearest cache (exact, shorter prefix, or create new)
            prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                self.model_key, cache_key
            )
            if prompt_cache is None:
                prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                remaining_tokens = cache_key

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request)

            print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

            # Create sampler with parameters
            sampler = make_sampler(**sampler_params)

            # Use stream_generate to track generated tokens for the cache key
            generated_text = []
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens if remaining_tokens else cache_key,
                max_tokens=max_tokens,
                sampler=sampler,
                prompt_cache=prompt_cache,
            ):
                generated_text.append(response.text)
                cache_key.append(response.token)

            # Insert completed cache
            self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)

            return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))
        except Exception as e:
            print(f"Error in MLX Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    def Embedding(self, request, context):
        """
        A gRPC method that would calculate embeddings for a given sentence.
        Note: MLX-LM doesn't support embeddings directly, so this method rejects the call.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An empty backend_pb2.EmbeddingResult; the RPC is marked UNIMPLEMENTED.
        """
        print("Embeddings not supported in MLX backend", file=sys.stderr)
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embeddings are not supported in the MLX backend.")
        return backend_pb2.EmbeddingResult()

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results using MLX.
        Uses a thread-safe LRU prompt cache for efficient prefix reuse across requests.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: Streaming predict results.
        """
        prompt_cache = None
        cache_key = None
        try:
            # Prepare the prompt and tokenize it for the cache key
            prompt_text = self._prepare_prompt(request)
            cache_key = self._get_tokens_from_prompt(prompt_text)

            # Fetch nearest cache (exact, shorter prefix, or create new)
            prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
                self.model_key, cache_key
            )
            if prompt_cache is None:
                prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
                remaining_tokens = cache_key

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)

            print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

            # Create sampler with parameters
            sampler = make_sampler(**sampler_params)

            # Stream text generation using MLX with the prepared parameters
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens if remaining_tokens else cache_key,
                max_tokens=max_tokens,
                sampler=sampler,
                prompt_cache=prompt_cache,
            ):
                cache_key.append(response.token)
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
        except Exception as e:
            print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming generation failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
        finally:
            # Always insert the cache, even on interruption
            if prompt_cache is not None and cache_key is not None:
                try:
                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
                except Exception as e:
                    print(f"Error inserting cache: {e}", file=sys.stderr)

    def _prepare_prompt(self, request):
        """
        Prepare the prompt for MLX generation, handling chat templates if needed.

        Args:
            request: The gRPC request containing prompt and message information.

        Returns:
            str: The prepared prompt.
        """
        # If the tokenizer template is enabled and messages are provided instead of a prompt, apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            # Convert gRPC messages to the format expected by apply_chat_template
            messages = []
            for msg in request.Messages:
                messages.append({"role": msg.role, "content": msg.content})
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            return prompt
        else:
            return request.Prompt

    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
        """
        Tokenize prompt text for cache key generation.

        Args:
            prompt_text: The prompt string to tokenize.

        Returns:
            List[int]: List of token IDs.
        """
        tokens = self.tokenizer.encode(prompt_text)
        if hasattr(tokens, 'tolist'):
            return tokens.tolist()
        return list(tokens)

    def _build_generation_params(self, request, default_max_tokens=200):
        """
        Build generation parameters from request attributes and options.

        Args:
            request: The gRPC request.
            default_max_tokens: Default max_tokens if not specified.

        Returns:
            tuple: (max_tokens, sampler_params dict)
        """
        # Extract max_tokens
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Extract sampler parameters from request attributes
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6  # Default temperature

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0  # Default top_p

        min_p = getattr(request, 'MinP', 0.0)
        # min_p default of 0.0 means disabled (no filtering)

        top_k = getattr(request, 'TopK', 0)
        # top_k default of 0 means disabled (no filtering)

        # Initialize sampler parameters
        sampler_params = {
            'temp': temp,
            'top_p': top_p,
            'min_p': min_p,
            'top_k': top_k,
            'xtc_threshold': 0.0,
            'xtc_probability': 0.0,
        }

        # Add seed if specified
        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Override with options if available
        if hasattr(self, 'options'):
            # Max tokens from options
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']

            # Sampler parameters from options
            sampler_option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',  # alias
                'top_p': 'top_p',
                'min_p': 'min_p',
                'top_k': 'top_k',
                'xtc_threshold': 'xtc_threshold',
                'xtc_probability': 'xtc_probability',
            }
            for option_key, param_key in sampler_option_mapping.items():
                if option_key in self.options:
                    sampler_params[param_key] = self.options[option_key]

            # Handle seed from options
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        # Special tokens for XTC sampling (if the tokenizer has eos_token_ids)
        xtc_special_tokens = []
        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            xtc_special_tokens = [self.tokenizer.eos_token_id]

        # Add newline token if available
        try:
            newline_tokens = self.tokenizer.encode("\n")
            xtc_special_tokens.extend(newline_tokens)
        except Exception:
            pass  # Skip if encoding fails

        sampler_params['xtc_special_tokens'] = xtc_special_tokens

        return max_tokens, sampler_params
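
# Illustrative only (hypothetical request, not from the original file): a request carrying
# Temperature=0.7, TopP=0.9, TopK=40 and no Tokens field would make _build_generation_params
# return max_tokens=200 and sampler_params like
#   {'temp': 0.7, 'top_p': 0.9, 'min_p': 0.0, 'top_k': 40,
#    'xtc_threshold': 0.0, 'xtc_probability': 0.0, 'xtc_special_tokens': [...]}
# which is then expanded into mlx_lm's make_sampler(**sampler_params).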

async def serve(address):
    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                             options=[
                                 ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                             ])
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Wait for the server to be terminated
    await server.wait_for_termination()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))
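
# Example invocation (illustrative; the script name here is a placeholder):
#   python3 <this_script>.py --addr "127.0.0.1:50051"
# With no --addr argument, the server falls back to the default localhost:50051.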