rellow / handler.py
rafacamargo's picture
chore: add necessary files for huggingface to expose an inference endpoint to the llm
751ad61
import torch
import sys
import os
# Extend the import search path with the `src` directory that sits next to
# this file. This must run before the import below.
# NOTE(review): presumably `prediction` lives in that `src` directory — confirm.
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
from prediction import main
class EndpointHandler:
    """Callable request handler that validates a three-word payload.

    On construction the serialized model is loaded from
    ``<model_dir>/src/model/rellow-2.pt`` onto the CPU and switched to eval
    mode. Each call validates the ``"words"`` entry of the request and
    forwards it to ``prediction.main``.
    """

    def __init__(self, model_dir, **kwargs) -> None:
        """Load the ``.pt`` model stored under *model_dir*.

        Args:
            model_dir: Directory that contains ``src/model/rellow-2.pt``.
            **kwargs: Accepted and ignored.
        """
        model_path = os.path.join(model_dir, "src", "model", "rellow-2.pt")
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code; only point this at a trusted checkpoint.
        # NOTE(review): the `.eval()` call below requires the file to contain
        # a full module object, not a state dict. Recent PyTorch releases
        # default to `weights_only=True`, which may refuse such a file —
        # confirm against the pinned torch version.
        self.model = torch.load(model_path, map_location="cpu")
        self.model.eval()

    def __call__(self, data: dict) -> dict:
        """Validate the request and run the prediction.

        Args:
            data: Request payload; the words are read from its ``"words"``
                key.

        Returns:
            ``{"generated": <result of main>}`` on success, or
            ``{"error": "Expected exactly three words"}`` when the payload
            is not a mapping or ``"words"`` is not a list/tuple of exactly
            three items.
        """
        words = data.get("words") if isinstance(data, dict) else None
        # Require a real sequence: a bare length check would let a
        # 3-character string or a 3-key dict through, and would raise
        # TypeError on values without a length (e.g. a number).
        if not isinstance(words, (list, tuple)) or len(words) != 3:
            return {"error": "Expected exactly three words"}
        # NOTE(review): `self.model` is not passed to `main`; presumably
        # `main` obtains its own model — confirm whether the load in
        # `__init__` is still needed.
        output = main(words=words)
        return {"generated": output}