othsueh commited on
Commit
f704679
·
verified ·
1 Parent(s): 6ce85cd

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -15
handler.py CHANGED
@@ -8,35 +8,23 @@ from modeling_upstream_finetune import UpstreamFinetune
8
  class EndpointHandler:
9
  def __init__(self, model_dir: str, **kwargs: Any) -> None:
10
  # Load config and model with trust_remote_code
11
- self.config = AutoConfig.from_pretrained(
12
- model_dir, trust_remote_code=True
13
- )
14
  self.model = UpstreamFinetune.from_pretrained(
15
  model_dir,
 
16
  trust_remote_code=True,
17
  # pass any kwargs like device mapping
18
  )
19
  self.model.eval()
20
- # Load processor (feature extractor + tokenizer)
21
- self.processor = AutoProcessor.from_pretrained(
22
- model_dir, trust_remote_code=True
23
- )
24
 
25
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
26
  # Expect raw audio bytes or a base64 string in `data["inputs"]`
27
  audio = data["inputs"]
28
  sr = data.get("sampling_rate", 16000)
29
- # Preprocess
30
- inputs = self.processor(
31
- audio,
32
- sampling_rate=sr,
33
- return_tensors="pt",
34
- padding=True
35
- )
36
  # Forward pass
37
  with torch.no_grad():
38
  cat_logits, reg_outputs = self.model(
39
- inputs.input_values.squeeze(0),
40
  sr
41
  )
42
  # Postprocess to Python types
 
8
  class EndpointHandler:
9
  def __init__(self, model_dir: str, **kwargs: Any) -> None:
10
  # Load config and model with trust_remote_code
11
+ device = 'cuda'
 
 
12
  self.model = UpstreamFinetune.from_pretrained(
13
  model_dir,
14
+ device=device,
15
  trust_remote_code=True,
16
  # pass any kwargs like device mapping
17
  )
18
  self.model.eval()
 
 
 
 
19
 
20
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
21
  # Expect raw audio bytes or a base64 string in `data["inputs"]`
22
  audio = data["inputs"]
23
  sr = data.get("sampling_rate", 16000)
 
 
 
 
 
 
 
24
  # Forward pass
25
  with torch.no_grad():
26
  cat_logits, reg_outputs = self.model(
27
+ audio,
28
  sr
29
  )
30
  # Postprocess to Python types