from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

Json = Dict[str, Any]


class EndpointHandler:
    """
    Minimal custom handler for Hugging Face Inference Endpoints.

    Implements __init__() to load the model/tokenizer
    and __call__() to handle inference requests.
    """

    def __init__(self, model_dir: str):
""" |
|
|
Called once on endpoint startup. |
|
|
|
|
|
Args: |
|
|
model_dir (str): Local path where the model repo was downloaded. |
|
|
""" |
|
|
|
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
|
model_dir, |
|
|
trust_remote_code=True, |
|
|
use_fast=True, |
|
|
) |
|
|
|
|
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
|
model_dir, |
|
|
trust_remote_code=True, |
|
|
) |
|
|
|
|
|
|
|
|
self.model.eval() |
|
|
|
|
|
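
    # A heavier-weight loading variant (a sketch, not used above): half
    # precision with automatic device placement. Assumes the `accelerate`
    # package is installed; the kwargs are standard transformers options.
    #
    #     self.model = AutoModelForCausalLM.from_pretrained(
    #         model_dir,
    #         trust_remote_code=True,
    #         torch_dtype=torch.float16,
    #         device_map="auto",
    #     )
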
    @torch.inference_mode()
    def __call__(self, data: Json) -> List[Json]:
        """
        Called for each inference request.

        Args:
            data (dict): {"inputs": str or list[str], "parameters": {...}}

        Returns:
            List[dict]: list of output dicts (each must be serializable).
        """
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        # Tokenize; padding=True pads list inputs to a common length so
        # they can be batched.
        enc = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
        )

        # Move tensors to whatever device the model lives on.
        device = next(self.model.parameters()).device
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)

        max_new_tokens = int(params.get("max_new_tokens", 128))
        temperature = float(params.get("temperature", 1.0))
        # temperature only takes effect when sampling; greedy decoding
        # (generate()'s default) silently ignores it.
        do_sample = bool(params.get("do_sample", False))

        # pad_token_id silences the open-ended-generation warning and
        # keeps batched padding well-defined.
        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.pad_token_id,
        )
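
        # If only the continuation is wanted, the prompt tokens can be
        # sliced off before decoding (a sketch; the handler below returns
        # the full decoded text, prompt included):
        #
        #     new_tokens = output_ids[:, input_ids.shape[1]:]
        #     text = self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
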
        # Decode each full sequence (prompt + continuation) into text.
        outputs = []
        for seq in output_ids:
            text = self.tokenizer.decode(seq, skip_special_tokens=True)
            outputs.append({"generated_text": text})

        return outputs
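

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch, not part of the Inference Endpoints
# contract. The model id "gpt2" is illustrative only; any small causal LM
# path or Hub id works. Run this file directly with Python to try it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler("gpt2")
    result = handler(
        {
            "inputs": "Hello, world",
            "parameters": {"max_new_tokens": 16, "do_sample": False},
        }
    )
    # Expected shape: [{"generated_text": "..."}]
    print(result)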