E-Hospital committed on
Commit
794ae9d
·
1 Parent(s): cbba02d

Set device=0 for pipeline

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -8,7 +8,7 @@ class EndpointHandler:
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
9
  self.tokenizer = AutoTokenizer.from_pretrained(path)
10
 
11
- self.pipeline = pipeline(task="text-generation", tokenizer=self.tokenizer, device_map="auto", framework="pt", model=self.model, max_length=512)
12
 
13
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
14
  """
@@ -25,9 +25,9 @@ class EndpointHandler:
25
 
26
  # pass inputs with all kwargs in data
27
  if parameters is not None:
28
- prediction = self.pipeline(inputs, **parameters)
29
  else:
30
- prediction = self.pipeline(inputs)
31
 
32
  # postprocess the prediction
33
  prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
9
  self.tokenizer = AutoTokenizer.from_pretrained(path)
10
 
11
+ self.pipeline = pipeline(task="text-generation", tokenizer=self.tokenizer, device=0, device_map="auto", framework="pt", model=self.model, max_length=512)
12
 
13
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
14
  """
 
25
 
26
  # pass inputs with all kwargs in data
27
  if parameters is not None:
28
+ prediction = self.pipeline(inputs, device=0, **parameters)
29
  else:
30
+ prediction = self.pipeline(inputs, device=0)
31
 
32
  # postprocess the prediction
33
  prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)