from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the model and tokenizer when the endpoint starts.

        Args:
            path (str): Path to the model files
        """
        logger.info(f"Loading model from {path}")

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Try to load without quantization first
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                load_in_8bit=False,
                load_in_4bit=False,
            )
        except Exception as e:
            logger.warning(f"Failed to load without quantization: {e}")
            # Fallback: retry with safetensors explicitly enabled
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                trust_remote_code=True,
                use_safetensors=True,
            )

        # Set pad token if it doesn't exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Model loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data (Dict[str, Any]): Request data containing:
                - inputs (str): The input text/prompt
                - parameters (dict, optional): Generation parameters
                    - max_new_tokens (int): Maximum tokens to generate (default: 256)
                    - temperature (float): Sampling temperature (default: 0.7)
                    - top_p (float): Top-p sampling (default: 0.9)
                    - do_sample (bool): Whether to use sampling (default: True)
                    - repetition_penalty (float): Repetition penalty (default: 1.1)
                    - return_full_text (bool): Return full text including input (default: False)

        Returns:
            List[Dict[str, Any]]: Generated text response
        """
        try:
            # Extract inputs
            inputs = data.get("inputs", "")
            if not inputs:
                return [{"error": "No input text provided"}]

            # Extract generation parameters
            parameters = data.get("parameters", {})
            max_new_tokens = parameters.get("max_new_tokens", 256)
            temperature = parameters.get("temperature", 0.7)
            top_p = parameters.get("top_p", 0.9)
            do_sample = parameters.get("do_sample", True)
            repetition_penalty = parameters.get("repetition_penalty", 1.1)
            return_full_text = parameters.get("return_full_text", False)

            # Wrap the prompt in [INST] tags unless it already carries
            # instruction-style formatting
            markers = ["[inst]", "### instruction", "user:", "assistant:"]
            if not any(marker in inputs.lower() for marker in markers):
                formatted_input = f"[INST] {inputs} [/INST]"
            else:
                formatted_input = inputs

            # Tokenize input; keep the attention mask so generate() does not
            # have to infer padding (pad token may equal the EOS token here)
            encoded = self.tokenizer(
                formatted_input,
                return_tensors="pt",
                truncation=True,
                max_length=2048,  # Reasonable limit for input
            )
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]

            # Move to GPU if available
            if torch.cuda.is_available():
                input_ids = input_ids.cuda()
                attention_mask = attention_mask.cuda()

            # Generate response
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True,
                )

            # Decode the response
            if return_full_text:
                generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            else:
                # Only return the newly generated tokens
                new_tokens = output_ids[0][input_ids.shape[-1]:]
                generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            # Clean up the response
            generated_text = generated_text.strip()

            # Return in the expected format
            return [{
                "generated_text": generated_text,
                "input_length": input_ids.shape[-1],
                "output_length": len(output_ids[0]) - input_ids.shape[-1],
            }]

        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}"}]

    def __del__(self):
        """Clean up resources when the handler is destroyed."""
        if hasattr(self, "model"):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
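

# --- Optional local smoke test (not part of the handler contract) ---
# A minimal sketch of how Hugging Face Inference Endpoints invokes a custom
# handler: it instantiates EndpointHandler with the model path, then calls it
# with a dict carrying "inputs" and optional "parameters". The checkpoint id
# below is an assumption for illustration; substitute any causal LM you have
# locally or on the Hub.
if __name__ == "__main__":
    handler = EndpointHandler(path="mistralai/Mistral-7B-Instruct-v0.2")  # hypothetical checkpoint
    payload = {
        "inputs": "Explain what a tokenizer does in one sentence.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    }
    # Expected shape: [{"generated_text": ..., "input_length": ..., "output_length": ...}]
    print(handler(payload))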