# handler.py
from __future__ import annotations

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Json = Dict[str, Any]
Messages = List[Dict[str, str]]  # [{"role": "user|assistant|system", "content": "..."}]


def _is_messages(x: Any) -> bool:
    return (
        isinstance(x, list)
        and len(x) > 0
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
    )


class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler.

    Expects:
      - request body is a dict
      - always contains `inputs`
      - may contain `parameters` for generation
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        # Pick dtype/device. bfloat16 is usually safe on A100/H100; fall back
        # to float16 on GPUs without bf16 support.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            self.dtype = torch.float32
        # IMPORTANT: trust_remote_code=True because the repo contains
        # AsteriskForCausalLM.py + auto_map.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=True,
        )
        # Make sure a pad token exists (your config uses pad_token_id=2, which
        # equals eos_token_id in many llama-like models).
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
        )
        if self.device != "cuda":
            self.model.to(self.device)
        self.model.eval()

    @torch.inference_mode()
    def __call__(self, data: Json) -> Union[Json, List[Json]]:
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        # Generation defaults (can be overridden via `parameters`)
        max_new_tokens = int(params.get("max_new_tokens", 256))
        temperature = float(params.get("temperature", 0.7))
        top_p = float(params.get("top_p", 0.95))
        top_k = int(params.get("top_k", 0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        do_sample = bool(params.get("do_sample", temperature > 0))
        num_beams = int(params.get("num_beams", 1))

        def _one(item: Any) -> Json:
            # Accept:
            #   1) string prompt
            #   2) messages list: [{"role": "user", "content": "..."}]
            #   3) dict {"messages": [...]} (common chat style)
            if isinstance(item, dict) and "messages" in item:
                item = item["messages"]
            if _is_messages(item):
                # A chat template exists in the repo; apply_chat_template will
                # use it if configured. return_dict=True also returns the
                # attention mask alongside the input ids.
                enc = self.tokenizer.apply_chat_template(
                    item,
                    return_tensors="pt",
                    add_generation_prompt=True,
                    return_dict=True,
                )
            else:
                if not isinstance(item, str):
                    item = str(item)
                enc = self.tokenizer(item, return_tensors="pt")
            input_ids = enc["input_ids"].to(self.model.device)
            # Pass the attention mask explicitly: pad_token_id == eos_token_id
            # here, so letting generate() infer the mask can misclassify tokens.
            attention_mask = enc["attention_mask"].to(self.model.device)
            input_len = input_ids.shape[-1]

            gen_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                top_k=top_k if do_sample and top_k > 0 else None,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )
            # Only return the newly generated tokens, not the prompt.
            new_tokens = gen_ids[0, input_len:]
            text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return {"generated_text": text}

        # Batch support: a list of prompts (but not a single messages list)
        if isinstance(inputs, list) and not _is_messages(inputs):
            return [_one(x) for x in inputs]
        return _one(inputs)
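

# A minimal local smoke test, not part of the Endpoints contract. The "./model"
# path is a placeholder -- point it at a directory containing the model files.
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    print(handler({"inputs": "Hello!", "parameters": {"max_new_tokens": 16}}))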