# Custom inference handler for the DeepSeekV3 architecture
# (Hugging Face Inference Endpoints custom handler)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class EndpointHandler:
    """Custom Inference Endpoints handler for DeepSeekV3-style causal LMs.

    Accepts either a plain prompt string or a chat-style list of message
    dicts under ``"inputs"`` and returns ``[{"generated_text": str}]``.
    """

    def __init__(self, path: str):
        """Load tokenizer and model from *path*.

        ``trust_remote_code=True`` is required because DeepSeekV3 ships
        custom modeling code; ``device_map="auto"`` shards the model across
        the available devices.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    def __call__(self, data):
        """Generate a completion for one request.

        Parameters
        ----------
        data : dict | str
            Either ``{"inputs": str | list[dict], "parameters": {...}}`` or
            a bare prompt. A list of messages is rendered with the model's
            chat template. Recognized parameters: ``max_new_tokens``
            (default 512) and ``temperature`` (default 0.7; ``<= 0`` means
            greedy decoding).

        Returns
        -------
        list[dict]
            ``[{"generated_text": <prompt + completion>}]``.
        """
        # Tolerate a bare (non-dict) payload: the original
        # `data.get("inputs", data)` raised AttributeError on a plain string.
        if isinstance(data, dict):
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {}) or {}
        else:
            inputs, parameters = data, {}

        if isinstance(inputs, list):
            # Chat-style request: render the message list into one prompt.
            prompt = self.tokenizer.apply_chat_template(
                inputs,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            prompt = inputs

        # Keep the attention mask (not just input_ids): with
        # pad_token_id == eos_token_id, generate() cannot reliably infer
        # the mask and warns / may attend over padding.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if temperature and temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
        else:
            # temperature <= 0 means greedy decoding; forwarding
            # temperature=0 is rejected/warned about by recent transformers.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # NOTE: decodes the full sequence, so the prompt is echoed back in
        # "generated_text" — this matches the original handler's behavior.
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": generated}]