File size: 592 Bytes
751ad61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import sys
import os

# Put the `src/` directory that sits next to this file on the import path so
# the `prediction` module imported below can be resolved (presumably it lives
# in `src/` -- confirm against the repository layout).
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
from prediction import main

class EndpointHandler:
  """Request handler that loads a checkpoint once and serves three-word requests.

  An instance is constructed with the model directory and is then called with
  a request payload; see ``__call__`` for the payload and response shapes.
  """

  def __init__(self, model_dir, **kwargs):
    """Load the checkpoint stored under ``model_dir`` and switch it to eval mode.

    Args:
      model_dir: Directory that contains ``src/model/rellow-2.pt``.
      **kwargs: Accepted for interface compatibility; not used here.
    """
    model_path = os.path.join(model_dir, "src", "model", "rellow-2.pt")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only point this at a checkpoint you trust. Newer PyTorch releases
    # default to weights_only=True, which may refuse a fully pickled model --
    # confirm against the PyTorch version used for deployment.
    self.model = torch.load(model_path, map_location="cpu")
    # Assumes the checkpoint is a full nn.Module (not a bare state_dict);
    # a state_dict has no .eval() -- TODO confirm.
    self.model.eval()

  def __call__(self, data: dict):
    """Validate the request and run the prediction on its three words.

    Args:
      data: Request payload; the words are read from its ``"words"`` key.

    Returns:
      ``{"generated": <result of main(words=...)>}`` on success, or
      ``{"error": "Expected exactly three words"}`` when ``"words"`` is
      missing, is not a list/tuple, or does not hold exactly three items.
    """
    words = data.get("words", [])
    # Require a real sequence of words: a bare length check would also accept
    # a 3-character string or a 3-key dict and pass it on as "three words".
    if not isinstance(words, (list, tuple)) or len(words) != 3:
      return {"error": "Expected exactly three words"}
    # NOTE(review): self.model is not passed to main(); presumably the
    # prediction module manages its own model -- confirm.
    output = main(words=words)
    return {"generated": output}