from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """
    Custom handler for the DoloresAI model - GREEDY DECODING ONLY.

    This avoids sampling issues with resized embeddings.
    """

    def __init__(self, path=""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path (str): Path to the model directory
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        self.model.eval()

        # Fail fast if the checkpoint was saved with a resized embedding
        # matrix but a stale tokenizer (the mismatch this handler guards
        # against by restricting itself to greedy decoding).
        assert self.model.config.vocab_size == len(self.tokenizer), \
            f"Vocab size mismatch: model={self.model.config.vocab_size}, tokenizer={len(self.tokenizer)}"

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Process inference requests using GREEDY DECODING ONLY.

        Args:
            data (Dict): Input data with format:
                {
                    "inputs": str,       # The prompt text
                    "parameters": {      # Optional generation parameters
                        "max_new_tokens": int
                    }
                }

        Returns:
            List[Dict]: Generated text response
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        max_new_tokens = parameters.get("max_new_tokens", 512)

        # Truncate the prompt so that prompt + completion fits within the
        # model's context window.
        input_ids = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings - max_new_tokens,
        ).input_ids.to(self.model.device)

        # Fall back to the EOS token for padding when the tokenizer defines
        # no dedicated pad token.
        pad_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding only
                num_beams=1,
                pad_token_id=pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Slicing the decoded string
        # by len(inputs) is unreliable, because special-token handling and
        # tokenizer normalization can change how the prompt is rendered when
        # decoded back to text.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return [{"generated_text": response_text}]
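

# ---------------------------------------------------------------------------
# Local smoke test -- a minimal sketch of how the handler is invoked, not part
# of the Inference Endpoints serving contract. MODEL_PATH is a hypothetical
# local checkpoint directory, not a path shipped with this repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    MODEL_PATH = "./model"  # hypothetical; point at a real checkpoint to run

    handler = EndpointHandler(path=MODEL_PATH)
    payload = {
        "inputs": "Explain greedy decoding in one sentence.",
        "parameters": {"max_new_tokens": 64},
    }
    print(handler(payload))  # -> [{"generated_text": "..."}]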