cxc / handler.py
min-samis2's picture
Add HF Inference Endpoints custom handler
cf062ee verified
Raw
History Blame Contribute Delete
1.51 kB
from typing import Any, Dict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for a PEFT adapter.

    Construction reads the adapter config stored at ``path`` to find the base
    causal-LM checkpoint, loads that base model and its tokenizer, and wraps
    the model with the adapter weights from ``path``. Calling the instance
    generates a continuation for one request payload.
    """

    # Fallbacks used when a request omits (or nulls) the matching parameter.
    _DEFAULT_MAX_NEW_TOKENS = 256
    _DEFAULT_TEMPERATURE = 0.7
    _DEFAULT_TOP_P = 0.9
    # String spellings of ``do_sample`` that are interpreted as "off".
    _FALSE_STRINGS = frozenset({"", "0", "false", "no", "off"})

    def __init__(self, path: str = ""):
        """Load tokenizer, base model and adapter.

        Args:
            path: Directory (or repo id) holding the PEFT adapter and its
                config; the base model name is taken from that config.
        """
        cfg = PeftConfig.from_pretrained(path)
        base = cfg.base_model_name_or_path

        self.tokenizer = AutoTokenizer.from_pretrained(base)
        # generate() is given a pad token id below, so fall back to EOS when
        # the tokenizer does not define one.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # NOTE(review): float16 presumably targets GPU deployment - confirm
        # before running this handler on a CPU-only instance.
        model = AutoModelForCausalLM.from_pretrained(
            base,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model = PeftModel.from_pretrained(model, path)
        self.model.eval()

    @staticmethod
    def _generation_kwargs(params):
        """Turn the request's ``parameters`` mapping into ``generate`` kwargs.

        Args:
            params: Mapping of optional request parameters
                (``max_new_tokens``, ``temperature``, ``top_p``,
                ``do_sample``), or ``None``. Missing or ``None`` entries use
                the class defaults.

        Returns:
            A dict that always contains ``max_new_tokens`` and ``do_sample``;
            ``temperature`` and ``top_p`` are included only when sampling is
            enabled.

        Raises:
            ValueError / TypeError: If a numeric parameter cannot be
                converted with ``int`` / ``float``.
        """
        params = params or {}

        def pick(key, default):
            value = params.get(key)
            return default if value is None else value

        do_sample = pick("do_sample", True)
        if isinstance(do_sample, str):
            # bool("false") is True, so string flags need explicit parsing.
            do_sample = (
                do_sample.strip().lower() not in EndpointHandler._FALSE_STRINGS
            )
        else:
            do_sample = bool(do_sample)

        temperature = float(
            pick("temperature", EndpointHandler._DEFAULT_TEMPERATURE)
        )
        # A non-positive temperature is rejected by the sampling code path;
        # treat it as the conventional request for greedy decoding instead.
        if temperature <= 0.0:
            do_sample = False

        kwargs = {
            "max_new_tokens": int(
                pick("max_new_tokens", EndpointHandler._DEFAULT_MAX_NEW_TOKENS)
            ),
            "do_sample": do_sample,
        }
        if do_sample:
            # Sampling-only knobs are forwarded only when they take effect.
            kwargs["temperature"] = temperature
            kwargs["top_p"] = float(
                pick("top_p", EndpointHandler._DEFAULT_TOP_P)
            )
        return kwargs

    def __call__(self, data: Dict[str, Any]):
        """Generate text for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt (defaults
                to ``""``); ``data["parameters"]`` optionally overrides the
                generation settings.

        Returns:
            A one-element list ``[{"generated_text": text}]`` where ``text``
            is the decoded continuation of the first sequence, without the
            prompt tokens and without special tokens.
        """
        prompt = data.get("inputs", "")
        gen_kwargs = self._generation_kwargs(data.get("parameters"))

        enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = enc["input_ids"].shape[1]

        with torch.no_grad():
            out = self.model.generate(
                **enc,
                pad_token_id=self.tokenizer.pad_token_id,
                **gen_kwargs,
            )

        # generate() returns prompt + continuation; keep only the new tokens.
        text = self.tokenizer.decode(
            out[0][prompt_len:], skip_special_tokens=True
        )
        return [{"generated_text": text}]