H-H-E commited on
Commit
6eea13c
·
1 Parent(s): 3e55f57

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -6
handler.py CHANGED
@@ -1,9 +1,10 @@
1
  from typing import Dict, List, Any
2
- from transformers import pipeline
3
 
4
  class EndpointHandler():
5
  def __init__(self, path=""):
6
- self.pipeline = pipeline("text-to-speech", "suno/bark")
 
7
 
8
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
9
  """
@@ -14,8 +15,13 @@ class EndpointHandler():
14
  A :obj:`list` | `dict`: will be serialized and returned
15
  """
16
  # get inputs
17
- inputs = data.pop("inputs",data)
 
 
 
 
 
 
 
 
18
 
19
- # run normal prediction
20
- prediction = self.pipeline(inputs)
21
- return prediction
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoProcessor,BarkModel
3
 
4
  class EndpointHandler():
5
  def __init__(self, path=""):
6
+ self.model = BarkModel.from_pretrained("suno/bark-small")
7
+ self.processor = AutoProcessor.from_pretrained("suno/bark")
8
 
9
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
10
  """
 
15
  A :obj:`list` | `dict`: will be serialized and returned
16
  """
17
  # get inputs
18
+ text_prompt = data.pop("inputs",data)
19
+
20
+ inputs = processor(text_prompt)
21
+
22
+ # run normal prediction
23
+ speech_output = model.generate(**inputs.to(device))
24
+ return speech_output
25
+
26
+
27