"""
Custom handler for Constitutional AI models - Fixed version
Removed no_repeat_ngram_size which may not be supported
"""
from typing import Dict, List, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler with model and tokenizer.

        Args:
            path: Path to the model directory.
        """
        # Load tokenizer; fall back to the EOS token when no pad token is defined
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model in half precision, letting device_map place it on available devices
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request.

        Args:
            data: A dictionary containing:
                - inputs (str): The input text
                - parameters (dict): Generation parameters

        Returns:
            A list containing the generated text.
        """
        # Get inputs; default to an empty string rather than the raw dict so
        # tokenization never receives a non-string
        inputs = data.pop("inputs", "")
        parameters = data.pop("parameters", {})

        # Default parameters chosen to match the local chatbot
        # (no_repeat_ngram_size intentionally omitted: it may not be
        # supported by all inference backends)
        max_new_tokens = parameters.get("max_new_tokens", 180)
        temperature = parameters.get("temperature", 0.7)
        do_sample = parameters.get("do_sample", True)
        top_p = parameters.get("top_p", 0.9)
        top_k = parameters.get("top_k", 50)
        repetition_penalty = parameters.get("repetition_penalty", 1.2)
        # Tokenize, keeping the attention mask so generate() can tell
        # padding from content; move tensors to wherever the model lives
        # (works on CPU-only hosts too, unlike an unconditional .cuda())
        encoded = self.tokenizer(inputs, return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.model.device)
        attention_mask = encoded["attention_mask"].to(self.model.device)
        # Generate with parameters matching the local chatbot
        # (minus the unsupported no_repeat_ngram_size)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens; slicing by input length is
        # more robust than a startswith() check, which can fail when the
        # decoded prompt differs from the raw input after tokenization
        new_tokens = outputs[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return [{"generated_text": generated_text}]
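

# A minimal local smoke test (a sketch, not part of the endpoint contract:
# the "./model" directory is an assumed path, and the payload simply mirrors
# the request shape documented in __call__)
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    response = handler({
        "inputs": "Hello, how are you?",
        "parameters": {"max_new_tokens": 32},
    })
    print(response[0]["generated_text"])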