andrewluo commited on
Commit
e81d4fe
·
1 Parent(s): b50421a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -2
handler.py CHANGED
@@ -9,7 +9,11 @@ class EndpointHandler():
9
  model = AutoModelForMaskedLM.from_pretrained(path)
10
  self.tokenizer = tokenizer
11
  self.model = model
12
-
 
 
 
 
13
  def __call__(self, data: Dict[str, Any]) -> List[Dict[Any, Any]]:
14
  """
15
  data args:
@@ -20,7 +24,7 @@ class EndpointHandler():
20
  """
21
  # get inputs
22
  text = data.pop("text", data)
23
- tokens = self.tokenizer(text, return_tensors='pt', padding=True)
24
  outputs = self.model(**tokens)
25
  results = []
26
  for idx, x in enumerate(outputs.logits):
 
9
  model = AutoModelForMaskedLM.from_pretrained(path)
10
  self.tokenizer = tokenizer
11
  self.model = model
12
+ if torch.cuda.is_available():
13
+ self.device = torch.device("cuda")
14
+ self.model.to(self.device)
15
+ else:
16
+ self.device = torch.device("cpu")
17
  def __call__(self, data: Dict[str, Any]) -> List[Dict[Any, Any]]:
18
  """
19
  data args:
 
24
  """
25
  # get inputs
26
  text = data.pop("text", data)
27
+ tokens = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
28
  outputs = self.model(**tokens)
29
  results = []
30
  for idx, x in enumerate(outputs.logits):