""" Custom handler for Constitutional AI models - Fixed version Removed no_repeat_ngram_size which may not be supported """ from typing import Dict, List, Any import torch from transformers import AutoModelForCausalLM, AutoTokenizer class EndpointHandler: def __init__(self, path=""): """ Initialize the handler with model and tokenizer Args: path: Path to the model directory """ # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(path) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Load model self.model = AutoModelForCausalLM.from_pretrained( path, torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True ) self.model.eval() def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: """ Process the inference request Args: data: A dictionary containing: - inputs (str): The input text - parameters (dict): Generation parameters Returns: List containing the generated text """ # Get inputs inputs = data.pop("inputs", data) parameters = data.pop("parameters", {}) # Set default parameters to match local chatbot (without no_repeat_ngram_size) max_new_tokens = parameters.get("max_new_tokens", 180) temperature = parameters.get("temperature", 0.7) do_sample = parameters.get("do_sample", True) top_p = parameters.get("top_p", 0.9) top_k = parameters.get("top_k", 50) repetition_penalty = parameters.get("repetition_penalty", 1.2) # REMOVED: no_repeat_ngram_size - may not be supported # Tokenize input_ids = self.tokenizer.encode(inputs, return_tensors="pt") # Move to same device as model if torch.cuda.is_available(): input_ids = input_ids.cuda() # Generate with parameters matching local chatbot (minus unsupported params) with torch.no_grad(): outputs = self.model.generate( input_ids, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=do_sample, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, # REMOVED: no_repeat_ngram_size pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id ) # Decode generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) # Remove the input prompt from the output if generated_text.startswith(inputs): generated_text = generated_text[len(inputs):].strip() return [{"generated_text": generated_text}]