import os
import sys

import torch

# Make the bundled ``src`` directory importable so that ``prediction``
# resolves regardless of the process's current working directory.
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))

from prediction import main  # noqa: E402  (requires the sys.path tweak above)

# Number of words the ``"words"`` field of the request payload must contain.
_EXPECTED_WORD_COUNT = 3


class EndpointHandler:
    """Request handler that loads a PyTorch model and generates from three words.

    NOTE(review): the ``__init__(model_dir, **kwargs)`` / ``__call__(data)``
    shape presumably follows the Hugging Face Inference Endpoints
    custom-handler interface — confirm against the deployment target.
    """

    def __init__(self, model_dir, **kwargs):
        """Load the serialized model and switch it to evaluation mode.

        Args:
            model_dir: Directory containing the deployed repository; the model
                is read from ``<model_dir>/src/model/rellow-2.pt``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        model_path = os.path.join(model_dir, "src", "model", "rellow-2.pt")
        # NOTE(review): torch.load unpickles the file, so it must come from a
        # trusted source. Newer torch releases default to weights_only=True,
        # which rejects fully pickled nn.Module objects — confirm this file's
        # format works with the torch version in use.
        self.model = torch.load(model_path, map_location="cpu")
        self.model.eval()

    def __call__(self, data: dict):
        """Validate the request and run ``prediction.main`` on the words.

        Args:
            data: Request payload; ``data["words"]`` must be a list (or tuple)
                of exactly three strings.

        Returns:
            ``{"generated": <result of main(words=...)>}`` on success, or
            ``{"error": "Expected exactly three words"}`` when the payload is
            missing, malformed, or has the wrong number of words.
        """
        inputs = data.get("words") if isinstance(data, dict) else None
        # Require a real sequence of strings: a bare string such as "abc" (or
        # a 3-key dict) would otherwise pass a plain len() == 3 check, and a
        # non-sized value would make len() raise instead of returning an error.
        if (
            not isinstance(inputs, (list, tuple))
            or len(inputs) != _EXPECTED_WORD_COUNT
            or not all(isinstance(word, str) for word in inputs)
        ):
            return {"error": "Expected exactly three words"}

        # NOTE(review): self.model is loaded in __init__ but never handed to
        # main(); confirm whether main() loads its own copy of the model.
        output = main(words=inputs)
        return {"generated": output}