File size: 592 Bytes
751ad61 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
from prediction import main
class EndpointHandler:
    """Request handler that loads a checkpoint and wraps ``prediction.main``.

    NOTE(review): the class name and the ``__init__(model_dir)`` /
    ``__call__(data)`` shape look like the Hugging Face Inference Endpoints
    custom-handler convention -- confirm against the deployment.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Load the checkpoint bundled under *model_dir*.

        Args:
            model_dir: Root directory of the deployed repository; the
                checkpoint is read from ``<model_dir>/src/model/rellow-2.pt``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        # os.path.join (rather than an f-string) keeps an empty model_dir
        # relative instead of silently turning it into "/src/...".
        model_path = os.path.join(model_dir, "src", "model", "rellow-2.pt")
        # torch.load unpickles the file, which can execute arbitrary code:
        # only point this at a trusted, bundled checkpoint.
        # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which may
        # refuse a fully pickled nn.Module -- confirm against the deployed
        # torch version.
        self.model = torch.load(model_path, map_location="cpu")
        # The loaded object must expose .eval(), i.e. a module rather than a
        # bare state_dict (assumption inferred from this call).
        self.model.eval()
        # NOTE(review): self.model is never passed to prediction.main in
        # __call__; presumably main loads its own model -- verify, since this
        # copy may be unused.

    def __call__(self, data: dict) -> dict:
        """Run generation for one request.

        Args:
            data: Request payload; its ``"words"`` entry must be a list (or
                tuple) of exactly three strings.

        Returns:
            ``{"generated": <result of prediction.main>}`` on success, or
            ``{"error": "Expected exactly three words"}`` when ``"words"`` is
            missing, ``None``, not a list/tuple, the wrong length, or contains
            a non-string element.
        """
        words = data.get("words", [])
        # A bare string like "abc" also has len() == 3, so check the container
        # type and the element types, not just the length.
        if (
            not isinstance(words, (list, tuple))
            or len(words) != 3
            or not all(isinstance(word, str) for word in words)
        ):
            return {"error": "Expected exactly three words"}
        output = main(words=words)
        return {"generated": output}
|