# gatortron-base / handler.py
# Uploaded by Zigla using huggingface_hub (commit 62a40cd, verified).
from transformers import AutoModel, AutoTokenizer
import torch
class EndpointHandler:
    """Inference-endpoint handler that returns per-token hidden states.

    The tokenizer and model are loaded from ``model_dir`` lazily, on the
    first request that actually needs them, rather than at construction.
    """

    def __init__(self, model_dir):
        """Remember where to load from; no loading happens here.

        Args:
            model_dir: Path or hub id passed to ``from_pretrained``.
        """
        self.model = None
        self.tokenizer = None
        self.model_dir = model_dir

    def _ensure_loaded(self):
        """Load the tokenizer and model on first use (each at most once)."""
        # Checked independently so a failed model load does not force the
        # already-loaded tokenizer to be reloaded on the next request.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.model is None:
            self.model = AutoModel.from_pretrained(self.model_dir)

    @staticmethod
    def _normalize_inputs(data):
        """Return the request's ``inputs`` field as a list of strings.

        Args:
            data: Request payload; must support ``.get("inputs")``.

        Returns:
            A single string wrapped in a list, or a list/tuple of strings
            copied into a list.

        Raises:
            ValueError: if ``inputs`` is missing, ``None``, or anything other
                than a string or a list/tuple of strings.
        """
        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return [inputs]
        if isinstance(inputs, (list, tuple)) and all(isinstance(t, str) for t in inputs):
            return list(inputs)
        raise ValueError(
            "'inputs' must be a string or a list of strings, "
            f"got {type(inputs).__name__}"
        )

    def __call__(self, data):
        """Embed the request text(s) and return the last hidden state.

        Args:
            data: Request payload with an ``"inputs"`` key holding a string
                or a list of strings.

        Returns:
            ``{"outputs": ...}`` where ``outputs`` is
            ``last_hidden_state.tolist()`` for the padded/truncated batch.
            When the tokenizer produces an attention mask, it is returned as
            ``"attention_mask"`` so callers can ignore padding positions.
            An empty input list yields ``{"outputs": []}`` without loading
            the model.

        Raises:
            ValueError: if ``inputs`` is missing or malformed (raised before
                any model loading).
        """
        # Validate first so malformed requests never trigger a model load.
        texts = self._normalize_inputs(data)
        if not texts:
            return {"outputs": []}

        self._ensure_loaded()

        encoded_input = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            outputs = self.model(**encoded_input)

        response = {"outputs": outputs.last_hidden_state.tolist()}
        # With padding=True, shorter texts get pad positions whose hidden
        # states are meaningless; expose the mask so callers can drop them.
        mask = encoded_input.get("attention_mask")
        if mask is not None:
            response["attention_mask"] = mask.tolist()
        return response