|
|
from typing import Dict, List, Any |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
class EndpointHandler():

    def __init__(self, path=""):
        """
        Initialize the model and tokenizer from the local checkpoint path.

        Uses Zenith Coder v1.1 custom code (modeling_deepseek.py,
        configuration_deepseek.py, tokenization_deepseek_fast.py), hence
        ``trust_remote_code=True`` on both loads.

        Args:
            path: Local directory containing the model/tokenizer files.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            # bf16 on GPU for speed/memory; fp32 on CPU where bf16 is often slow.
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Generate text for a single prompt.

        Args:
            data: Request payload. Reads ``inputs`` (or ``prompt``) as the
                prompt string, plus optional sampling knobs:
                ``max_new_tokens`` (default 256), ``temperature`` (default 1.0),
                ``top_p`` (default 0.95), ``top_k`` (default 50).

        Returns:
            ``[{"generated_text": ...}]`` on success, or
            ``[{"error": ...}]`` when no valid prompt was supplied.
        """
        prompt = data.get("inputs") or data.get("prompt")
        if not prompt or not isinstance(prompt, str):
            return [{"error": "No valid input prompt provided."}]

        max_new_tokens = int(data.get("max_new_tokens", 256))
        temperature = float(data.get("temperature", 1.0))
        top_p = float(data.get("top_p", 0.95))
        top_k = int(data.get("top_k", 50))

        # Tokenize once and keep the attention mask: generate() warns (and can
        # misbehave with padding) when the mask is omitted.
        encoded = self.tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the embedding layer may live on any device,
        # so move inputs to the model's actual device instead of .cuda().
        device = next(self.model.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        # Many causal-LM tokenizers define no pad token; fall back to EOS so
        # generate() does not fail on batched/padded paths.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        # temperature <= 0 is invalid for sampling in transformers; treat it
        # as a request for greedy decoding instead of raising.
        do_sample = temperature > 0.0

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                top_k=top_k if do_sample else None,
                pad_token_id=pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt tokens so only the continuation is returned.
        gen_text = self.tokenizer.decode(
            generated_ids[0][input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return [{"generated_text": gen_text}]
|
|
|