from typing import Dict, List, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """
    Custom handler for DoloresAI model - GREEDY DECODING ONLY.
    This avoids sampling issues with resized embeddings.
    """

    def __init__(self, path=""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path (str): Path to the model directory
        """
        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        # Verify vocab sizes match
        assert self.model.config.vocab_size == len(self.tokenizer), \
            f"Vocab size mismatch: model={self.model.config.vocab_size}, tokenizer={len(self.tokenizer)}"
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Process inference requests using GREEDY DECODING ONLY.

        Args:
            data (Dict): Input data with format:
                {
                    "inputs": str,        # The prompt text
                    "parameters": {       # Optional generation parameters
                        "max_new_tokens": int
                    }
                }

        Returns:
            List[Dict]: Generated text response
        """
        # Extract inputs
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Get max tokens (only parameter we use)
        max_new_tokens = parameters.get("max_new_tokens", 512)

        # Tokenize input
        input_ids = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings - max_new_tokens,
        ).input_ids.to(self.model.device)

        # Generate response with GREEDY DECODING ONLY
        # This is stable and avoids NaN/inf errors from sampling
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # GREEDY - no sampling
                num_beams=1,      # No beam search
                pad_token_id=(
                    self.tokenizer.pad_token_id
                    if self.tokenizer.pad_token_id is not None
                    else self.tokenizer.eos_token_id
                ),
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode output
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Remove the input prompt from the response
        response_text = generated_text[len(inputs):].strip()

        return [{"generated_text": response_text}]