import inspect
import os
import sys

import torch

# Make the repository's ``src`` directory importable so ``prediction`` resolves.
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from prediction import main  # noqa: E402 -- must come after the sys.path change
class EndpointHandler:
    """Inference handler that generates output from exactly three input words.

    A serialized PyTorch model is loaded once at construction time. Each
    request is validated and then delegated to ``prediction.main``.
    """

    def __init__(self, model_dir, **kwargs):
        """Load the bundled model checkpoint onto the CPU in eval mode.

        Args:
            model_dir: Directory that contains ``src/model/rellow-2.pt``.
            **kwargs: Accepted for caller compatibility; not used.
        """
        model_path = os.path.join(model_dir, "src", "model", "rellow-2.pt")
        self.model = self._load_model(model_path)
        self.model.eval()
        # NOTE(review): __call__ never uses self.model; it calls
        # prediction.main(words=...) without passing the model. Confirm whether
        # main loads its own copy (double memory) or should receive this one.

    @staticmethod
    def _load_model(path):
        """Return the full pickled model stored at *path*, mapped to CPU."""
        load_kwargs = {"map_location": "cpu"}
        # PyTorch 2.6 flipped torch.load's default to weights_only=True, which
        # refuses to unpickle a whole nn.Module. Opt out explicitly whenever the
        # installed torch knows the flag (older versions reject the kwarg).
        # SECURITY: weights_only=False executes arbitrary pickle code. This is
        # acceptable only on the assumption that the checkpoint is a trusted
        # file shipped with the repository -- never load user-supplied paths.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        return torch.load(path, **load_kwargs)

    def __call__(self, data: dict):
        """Generate output for a request carrying exactly three words.

        Args:
            data: Request payload; its ``"words"`` entry must be a list or
                tuple of exactly three strings.

        Returns:
            ``{"generated": <result of prediction.main>}`` on success, or
            ``{"error": "Expected exactly three words"}`` for any invalid
            payload (including a payload that is not a mapping).
        """
        try:
            words = data.get("words")
        except AttributeError:
            # Payload is not a mapping (e.g. a raw string body).
            words = None
        # A bare string such as "cat" has len 3 but is one word, so only real
        # sequences of strings are accepted.
        if (
            not isinstance(words, (list, tuple))
            or len(words) != 3
            or not all(isinstance(word, str) for word in words)
        ):
            return {"error": "Expected exactly three words"}
        output = main(words=words)
        return {"generated": output}