from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """
    Custom handler for DoloresAI model - GREEDY DECODING ONLY

    This avoids sampling issues with resized embeddings.
    """

    def __init__(self, path=""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path (str): Path to the model directory
        """
        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        # Verify vocab sizes match
        assert self.model.config.vocab_size == len(self.tokenizer), (
            f"Vocab size mismatch: model={self.model.config.vocab_size}, "
            f"tokenizer={len(self.tokenizer)}"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Process inference requests using GREEDY DECODING ONLY.

        Args:
            data (Dict): Input data with format:
                {
                    "inputs": str,       # The prompt text
                    "parameters": {      # Optional generation parameters
                        "max_new_tokens": int
                    }
                }

        Returns:
            List[Dict]: Generated text response
        """
        # Extract inputs
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Get max tokens (only parameter we use)
        max_new_tokens = parameters.get("max_new_tokens", 512)

        # Tokenize input, leaving room in the context window for generation
        input_ids = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=max(
                self.model.config.max_position_embeddings - max_new_tokens, 1
            ),
        ).input_ids.to(self.model.device)

        # Fall back to EOS as the pad token if the tokenizer has none
        pad_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )

        # Generate response with GREEDY DECODING ONLY
        # This is stable and avoids NaN/inf errors from sampling
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # GREEDY - no sampling
                num_beams=1,      # No beam search
                pad_token_id=pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Slicing the decoded string
        # by len(inputs) is unreliable, because decoding with
        # skip_special_tokens=True does not round-trip the prompt exactly
        # (special tokens and whitespace normalization shift the offsets).
        generated_tokens = outputs[0][input_ids.shape[-1]:]
        response_text = self.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        ).strip()

        return [{"generated_text": response_text}]
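
# --- Usage sketch (illustrative, not part of the handler contract) ---
# A minimal local smoke test, assuming a model directory at "./model" (a
# hypothetical path for this example). Hugging Face Inference Endpoints
# instantiate EndpointHandler once and call it per request with the JSON
# payload shape documented in __call__'s docstring; this guard mimics that
# flow for local debugging.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    payload = {
        "inputs": "Explain greedy decoding in one sentence.",
        "parameters": {"max_new_tokens": 64},
    }
    # Expected: a single-element list like [{"generated_text": "..."}]
    print(handler(payload))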