File size: 2,266 Bytes
cef3e59 86c8190 e890fdc cef3e59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import Dict, Any, List, Union
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TextClassificationPipeline,
)
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    Loads a fine-tuned text-classification model and exposes a callable
    that the endpoint runtime invokes. The runtime instantiates this class
    once at startup, passing the model directory path.
    """

    def __init__(self, path: str = "", **kwargs: Any) -> None:
        """Load tokenizer and model from *path* and build the pipeline.

        Args:
            path: Directory holding the model artefacts. Falls back to the
                current directory when empty (useful for local testing).
            **kwargs: Accepted for compatibility with the runtime; unused.
        """
        # Local import: torch is already a hard requirement of
        # AutoModelForSequenceClassification and only needed here.
        import torch

        model_dir = path or "."

        # Load tokenizer & model.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)

        # device=-1 pins the pipeline to CPU, so choose CUDA device 0
        # explicitly when a GPU is visible.
        device = 0 if torch.cuda.is_available() else -1

        self.pipeline = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            device=device,
            # Deprecated upstream in favour of `top_k`, but kept: switching
            # to `top_k=1` would turn each per-input result from a dict into
            # a one-element list, changing the response clients receive.
            return_all_scores=False,
            function_to_apply=self._select_function_to_apply(self.model.config),
        )

    @staticmethod
    def _select_function_to_apply(config: Any) -> str:
        """Return the score activation name for a model *config*.

        * ``"none"`` for regression heads: the raw output is the prediction.
        * ``"sigmoid"`` for multi-label heads and for single-logit heads,
          where a softmax over one logit would always yield 1.0.
        * ``"softmax"`` otherwise (ordinary single-label classification).
        """
        problem_type = getattr(config, "problem_type", None)
        if problem_type == "regression":
            return "none"
        if (
            problem_type == "multi_label_classification"
            or getattr(config, "num_labels", None) == 1
        ):
            return "sigmoid"
        return "softmax"

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Run inference on the incoming request.

        Expected input format from the Inference Endpoint runtime::

            {
                "inputs": "some text" | ["text 1", "text 2", ...],
                "parameters": { ... }  # optional pipeline kwargs (e.g. top_k)
            }

        If ``"inputs"`` is absent, the whole payload is passed to the pipeline.

        Returns:
            For a single string input, that string's result. For a list
            input (of any length, including one), a list with one result per
            item; an empty list yields ``[]``.
        """
        raw_inputs = data.get("inputs", data)

        # Only a bare string is a "single" request; a list is always a batch,
        # so batch callers get a list back regardless of its length.
        is_single = isinstance(raw_inputs, str)
        batch = [raw_inputs] if is_single else raw_inputs
        if isinstance(batch, list) and not batch:
            return []

        # `"parameters": null` must behave like an absent key instead of
        # crashing the `**` expansion below.
        parameters = data.get("parameters") or {}

        outputs = self.pipeline(batch, **parameters)
        return outputs[0] if is_single else outputs