#!/usr/bin/env python3
import argparse
import asyncio
import base64
import io
import os
import signal
import sys
import tempfile
from concurrent import futures
from typing import List

import grpc
import mlx.core as mx
from PIL import Image
from mlx_vlm import load, generate, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

import backend_pb2
import backend_pb2_grpc

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

# If PYTHON_GRPC_MAX_WORKERS is specified in the environment, use it; otherwise default to 1.
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
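# Example (illustrative): PYTHON_GRPC_MAX_WORKERS=4 sizes the thread pool that
# grpc.aio uses below to run non-async handlers such as Health().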
# Implement the BackendServicer class with the service methods.
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
    async def LoadModel(self, request, context):
        """
        Loads a multimodal vision-language model using MLX-VLM.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the diffusers backend.
            options = request.Options
            self.options = {}

            # The options are a list of strings in the form "optname:optvalue";
            # we store all of them in a dict for later use.
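            # For example (illustrative values), the option list
            # ["temperature:0.2", "max_tokens:512", "seed:42"] is parsed into
            # {"temperature": 0.2, "max_tokens": 512, "seed": 42}.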
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on the first colon to handle values containing colons.
                # Check int before float so integer-valued options (e.g. max_tokens)
                # stay ints; is_float() also accepts integer strings.
                if is_int(value):
                    value = int(value)
                elif is_float(value):
                    value = float(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"
                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Load model and processor using MLX-VLM.
            # mlx-vlm's load() returns (model, processor) instead of (model, tokenizer).
            self.model, self.processor = load(request.Model)

            # Load the model config for chat template support.
            self.config = load_config(request.Model)
        except Exception as err:
            print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")

        print("MLX-VLM model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX-VLM model loaded successfully", success=True)
    async def Predict(self, request, context):
        """
        Generates text from the given prompt and sampling parameters using MLX-VLM,
        with multimodal (image/audio) support.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        temp_files = []
        try:
            # Decode images and audio clips from the request into temporary files.
            image_paths = []
            audio_paths = []

            if request.Images:
                for img_data in request.Images:
                    img_path = self.load_image_from_base64(img_data)
                    if img_path:
                        image_paths.append(img_path)
                        temp_files.append(img_path)

            if request.Audios:
                for audio_data in request.Audios:
                    audio_path = self.load_audio_from_base64(audio_data)
                    if audio_path:
                        audio_paths.append(audio_path)
                        temp_files.append(audio_path)

            # Prepare the prompt with multimodal information.
            prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))

            # Build generation parameters from request attributes and options.
            max_tokens, generation_params = self._build_generation_params(request)

            print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
            print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

            # Generate text using MLX-VLM with multimodal inputs.
            response = generate(
                model=self.model,
                processor=self.processor,
                prompt=prompt,
                image=image_paths if image_paths else None,
                audio=audio_paths if audio_paths else None,
                max_tokens=max_tokens,
                temperature=generation_params.get('temp', 0.6),
                top_p=generation_params.get('top_p', 1.0),
                verbose=False
            )
            return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
        except Exception as e:
            print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
        finally:
            # Clean up temporary files.
            self.cleanup_temp_files(temp_files)
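    # Illustrative client-side call (not executed here). Message names follow
    # LocalAI's backend.proto; field values are hypothetical:
    #
    #   channel = grpc.insecure_channel("localhost:50051")
    #   stub = backend_pb2_grpc.BackendStub(channel)
    #   stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Qwen2-VL-2B-Instruct-4bit"))
    #   reply = stub.Predict(backend_pb2.PredictOptions(Prompt="Describe this image",
    #                                                   Images=[b64_image]))
    #   print(reply.message.decode("utf-8"))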
    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Note: MLX-VLM doesn't support embeddings directly, so this method
        returns an UNIMPLEMENTED error.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An empty EmbeddingResult object.
        """
        print("Embeddings not supported in MLX-VLM backend", file=sys.stderr)
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embeddings are not supported in the MLX-VLM backend.")
        return backend_pb2.EmbeddingResult()
    async def PredictStream(self, request, context):
        """
        Generates text from the given prompt and sampling parameters using MLX-VLM,
        with multimodal (image/audio) support, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: Streaming predict results.
        """
        temp_files = []
        try:
            # Decode images and audio clips from the request into temporary files.
            image_paths = []
            audio_paths = []

            if request.Images:
                for img_data in request.Images:
                    img_path = self.load_image_from_base64(img_data)
                    if img_path:
                        image_paths.append(img_path)
                        temp_files.append(img_path)

            if request.Audios:
                for audio_data in request.Audios:
                    audio_path = self.load_audio_from_base64(audio_data)
                    if audio_path:
                        audio_paths.append(audio_path)
                        temp_files.append(audio_path)

            # Prepare the prompt with multimodal information.
            prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))

            # Build generation parameters from request attributes and options.
            max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512)

            print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
            print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

            # Stream text generation using MLX-VLM with multimodal inputs.
            for response in stream_generate(
                model=self.model,
                processor=self.processor,
                prompt=prompt,
                image=image_paths if image_paths else None,
                audio=audio_paths if audio_paths else None,
                max_tokens=max_tokens,
                temperature=generation_params.get('temp', 0.6),
                top_p=generation_params.get('top_p', 1.0),
            ):
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
        except Exception as e:
            print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming generation failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
        finally:
            # Clean up temporary files.
            self.cleanup_temp_files(temp_files)
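    # Illustrative streaming consumption from a client (hypothetical values;
    # message names per LocalAI's backend.proto):
    #
    #   for chunk in stub.PredictStream(backend_pb2.PredictOptions(Prompt="Hi")):
    #       sys.stdout.write(chunk.message.decode("utf-8"))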
    def _prepare_prompt(self, request, num_images=0, num_audios=0):
        """
        Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs.

        Args:
            request: The gRPC request containing prompt and message information.
            num_images: Number of images in the request.
            num_audios: Number of audio files in the request.

        Returns:
            str: The prepared prompt.
        """
        # If the tokenizer template is enabled and messages are provided instead
        # of a prompt, apply the chat template.
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            # Convert gRPC messages to the format expected by apply_chat_template.
            messages = []
            for msg in request.Messages:
                messages.append({"role": msg.role, "content": msg.content})

            # Use mlx-vlm's apply_chat_template, which handles multimodal inputs.
            prompt = apply_chat_template(
                self.processor,
                self.config,
                messages,
                num_images=num_images,
                num_audios=num_audios
            )
            return prompt
        elif request.Prompt:
            # If we have a direct prompt but also images/audio, format it through
            # the chat template so the media placeholders are inserted.
            if num_images > 0 or num_audios > 0:
                messages = [{"role": "user", "content": request.Prompt}]
                prompt = apply_chat_template(
                    self.processor,
                    self.config,
                    messages,
                    num_images=num_images,
                    num_audios=num_audios
                )
                return prompt
            else:
                return request.Prompt
        else:
            # Fall back to an empty prompt, still applying the multimodal
            # template if we have media.
            if num_images > 0 or num_audios > 0:
                messages = [{"role": "user", "content": ""}]
                prompt = apply_chat_template(
                    self.processor,
                    self.config,
                    messages,
                    num_images=num_images,
                    num_audios=num_audios
                )
                return prompt
            else:
                return ""
    def _build_generation_params(self, request, default_max_tokens=200):
        """
        Build generation parameters from request attributes and options for MLX-VLM.

        Args:
            request: The gRPC request.
            default_max_tokens: Default max_tokens if not specified.

        Returns:
            tuple: (max_tokens, generation_params dict)
        """
        # Extract max_tokens.
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Extract generation parameters from request attributes.
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6  # Default temperature

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0  # Default top_p

        # Initialize generation parameters for MLX-VLM.
        generation_params = {
            'temp': temp,
            'top_p': top_p,
        }

        # Seed the MLX RNG if a seed is specified.
        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Override with options if available.
        if hasattr(self, 'options'):
            # Max tokens from options.
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']

            # Generation parameters from options.
            param_option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',  # alias
                'top_p': 'top_p',
            }
            for option_key, param_key in param_option_mapping.items():
                if option_key in self.options:
                    generation_params[param_key] = self.options[option_key]

            # Handle seed from options.
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        return max_tokens, generation_params
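    # Worked example (illustrative): a request with Temperature=0.0 and TopP=0.0
    # falls back to the defaults (0.6, 1.0), but an option string "temperature:0.2"
    # supplied at LoadModel time overrides the temperature, so this returns
    # (max_tokens, {'temp': 0.2, 'top_p': 1.0}).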
    def load_image_from_base64(self, image_data: str):
        """
        Load an image from base64-encoded data.

        Args:
            image_data (str): Base64-encoded image data.

        Returns:
            str or None: Path to a temporary image file, or None on failure.
        """
        try:
            decoded_data = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(decoded_data))
            # JPEG cannot store an alpha channel or palette, so normalize to RGB first.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Save to a temporary file for mlx-vlm.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
                image.save(tmp_file.name, format='JPEG')
                return tmp_file.name
        except Exception as e:
            print(f"Error loading image from base64: {e}", file=sys.stderr)
            return None
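    # Illustrative round-trip (hypothetical file name and instance):
    #
    #   with open("cat.png", "rb") as f:
    #       b64 = base64.b64encode(f.read()).decode("utf-8")
    #   path = servicer.load_image_from_base64(b64)  # -> "/tmp/tmpXXXXXXXX.jpg"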
    def load_audio_from_base64(self, audio_data: str):
        """
        Load audio from base64-encoded data.

        Args:
            audio_data (str): Base64-encoded audio data.

        Returns:
            str or None: Path to a temporary audio file, or None on failure.
        """
        try:
            decoded_data = base64.b64decode(audio_data)
            # Save to a temporary file for mlx-vlm. The bytes are written verbatim,
            # so the payload should already be in a format the model accepts (e.g. WAV).
            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
                tmp_file.write(decoded_data)
                return tmp_file.name
        except Exception as e:
            print(f"Error loading audio from base64: {e}", file=sys.stderr)
            return None
    def cleanup_temp_files(self, file_paths: List[str]):
        """
        Clean up temporary files.

        Args:
            file_paths (List[str]): List of file paths to clean up.
        """
        for file_path in file_paths:
            try:
                if file_path and os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                print(f"Error removing temporary file {file_path}: {e}", file=sys.stderr)

async def serve(address):
    # Start an asyncio gRPC server.
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
        ],
    )

    # Add the servicer to the server.
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address.
    server.add_insecure_port(address)

    # Gracefully shut down the server on SIGTERM or SIGINT.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server.
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Wait for the server to be terminated.
    await server.wait_for_termination()

| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Run the gRPC server.") | |
| parser.add_argument( | |
| "--addr", default="localhost:50051", help="The address to bind the server to." | |
| ) | |
| args = parser.parse_args() | |
| asyncio.run(serve(args.addr)) | |
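# Example invocation (illustrative; script name hypothetical):
#
#   PYTHON_GRPC_MAX_WORKERS=2 python backend.py --addr "127.0.0.1:50051"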