from __future__ import annotations

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Json = Dict[str, Any]
Messages = List[Dict[str, str]]


def _is_messages(x: Any) -> bool:
    """Return True if x looks like a chat history: a non-empty list of
    dicts that each carry a "role" and a "content" key."""
    return (
        isinstance(x, list)
        and len(x) > 0
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
    )
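
# Illustrative only: a value that _is_messages accepts looks like
#   [{"role": "system", "content": "You are concise."},
#    {"role": "user", "content": "Hello!"}]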


class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler.

    Expects the request body to be a dict that:
      - always contains `inputs`
      - may contain `parameters` with generation options
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir

        # Prefer bfloat16 on GPU; fall back to float32 on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=True,
        )

        # Many causal LMs ship without a pad token; reuse EOS so that
        # generate() can pad without erroring.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
        )

        # With device_map="auto" the model is already placed on the GPU(s);
        # only move it explicitly when running on CPU.
        if self.device != "cuda":
            self.model.to(self.device)

        self.model.eval()

    @torch.inference_mode()
    def __call__(self, data: Json) -> Union[Json, List[Json]]:
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        # Generation options, with conservative defaults.
        max_new_tokens = int(params.get("max_new_tokens", 256))
        temperature = float(params.get("temperature", 0.7))
        top_p = float(params.get("top_p", 0.95))
        top_k = int(params.get("top_k", 0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))

        # Sample by default unless the caller disables it (or sets
        # temperature to 0, which implies greedy decoding).
        do_sample = bool(params.get("do_sample", temperature > 0))
        num_beams = int(params.get("num_beams", 1))
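
        # Example `parameters` payload (illustrative, not exhaustive):
        #   {"max_new_tokens": 128, "temperature": 0.2, "top_p": 0.9,
        #    "do_sample": true, "repetition_penalty": 1.1}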

        def _one(item: Any) -> Json:
            # Accept {"messages": [...]}, a bare messages list, or a plain
            # prompt string; anything else is coerced to str.
            if isinstance(item, dict) and "messages" in item:
                item = item["messages"]

            if _is_messages(item):
                # Chat input: render with the model's chat template and
                # append the generation prompt.
                input_ids = self.tokenizer.apply_chat_template(
                    item,
                    return_tensors="pt",
                    add_generation_prompt=True,
                )
            else:
                if not isinstance(item, str):
                    item = str(item)
                enc = self.tokenizer(item, return_tensors="pt")
                input_ids = enc["input_ids"]

            input_ids = input_ids.to(self.model.device)
            input_len = input_ids.shape[-1]

            gen_ids = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                top_k=top_k if do_sample and top_k > 0 else None,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

            # Return only the newly generated text, not the echoed prompt.
            new_tokens = gen_ids[0, input_len:]
            text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return {"generated_text": text}

        # A list of prompts (that is not itself a chat history) is treated
        # as a batch; everything else is a single request.
        if isinstance(inputs, list) and not _is_messages(inputs):
            return [_one(x) for x in inputs]
        return _one(inputs)
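

# Minimal local smoke test, a sketch only: it assumes a model directory at
# "./model" (hypothetical path). Inference Endpoints constructs
# EndpointHandler itself, so this block never runs in production.
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    print(handler({"inputs": "Write a haiku about the sea.",
                   "parameters": {"max_new_tokens": 32}}))
    print(handler({
        "inputs": {"messages": [{"role": "user", "content": "Say hi."}]},
        "parameters": {"max_new_tokens": 16, "do_sample": False},
    }))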