# date2format / handler.py
# Author: syarulzaffi — commit b8473d8 ("Create handler.py", verified), 885 bytes.
from typing import Any, Dict, List, Union

from transformers import AutoTokenizer, pipeline
class EndpointHandler:
    """Custom inference handler wrapping a ``transformers`` text-classification pipeline.

    The pipeline is built once in ``__init__``; each request payload is then
    passed to ``__call__``.
    """

    def __init__(self, path=""):
        """Load the tokenizer and build the text-classification pipeline.

        Args:
            path: Model location handed to both ``AutoTokenizer.from_pretrained``
                and ``pipeline(model=...)`` (presumably the endpoint's model
                directory — confirm against the deployment).
        """
        # Load the tokenizer from the same location as the model.
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Create the inference pipeline for text classification.
        self.pipeline = pipeline("text-classification", model=path, tokenizer=tokenizer)

    def __call__(self, data: Union[str, Dict[str, Any]]) -> List[Any]:
        """Classify the text contained in a request payload.

        Args:
            data: Either the raw input itself (e.g. a plain string) or a dict
                payload. For a dict, the value under ``"inputs"`` is
                classified. If that key is absent, the whole dict is forwarded
                to the pipeline unchanged. The caller's dict is never modified.

        Returns:
            The pipeline's prediction result: a list of
            ``{"label": ..., "score": ...}`` dicts. It may be nested one level
            deeper depending on the input and pipeline options — confirm
            against the deployed ``transformers`` version.
        """
        if isinstance(data, dict):
            # Read rather than pop so the caller's payload is left untouched.
            inputs = data.get("inputs", data)
        else:
            # Raw input such as a plain string: pass it through as-is
            # (``str`` has no ``.pop``, so it must not be treated as a dict).
            inputs = data
        return self.pipeline(inputs)