juan4pro12 committed
Commit b4e788b · verified
1 Parent(s): 7d38920

Fix custom handler runtime compatibility

Files changed (3)
  1. README.md +4 -11
  2. handler.py +13 -7
  3. requirements.txt +2 -1
README.md CHANGED
@@ -2,15 +2,8 @@
 
 Custom handler for deploying `microsoft/VibeVoice-ASR-HF` on a dedicated Hugging Face Inference Endpoint.
 
-## Files
+## Notes
 
-- `handler.py`: custom handler for the endpoint.
-- `requirements.txt`: additional dependencies.
-- `deploy_endpoint.py`: reference script for deploying the dedicated endpoint.
-
-## Expected configuration
-
-- Target repo on HF: `juan4pro12/vibevoice-custom-handler`
-- Dedicated endpoint protected with a token
-- Hardware: `nvidia-t4` / `small`
-- Task: `custom`
+- Use `task=custom`.
+- Use a T4 GPU (`aws-us-east-1-nvidia-t4-x1`).
+- The `transformers` dependency is installed from the fork recommended by VibeVoice to maintain compatibility with the endpoint runtime.
handler.py CHANGED
@@ -13,13 +13,19 @@ class EndpointHandler:
 
     def __call__(self, data):
         inputs_data = data.pop("inputs", data)
-        inputs = self.processor(audio=inputs_data, return_tensors="pt").to(
-            self.model.device,
-            self.model.dtype,
-        )
+        prompt = data.pop("prompt", None)
+        inputs = self.processor.apply_transcription_request(
+            audio=inputs_data,
+            prompt=prompt,
+            return_tensors="pt",
+        ).to(self.model.device, self.model.dtype)
 
         with torch.no_grad():
-            generated_ids = self.model.generate(**inputs)
+            output_ids = self.model.generate(**inputs)
 
-        transcription = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
-        return {"text": transcription[0]}
+        generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
+        transcription = self.processor.decode(
+            generated_ids,
+            return_format="transcription_only",
+        )[0]
+        return {"text": transcription}
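
Note on the handler change: the new code reads an optional `prompt` from the request payload and slices the prompt tokens off the `generate()` output before decoding. This presumably matters because the output of `generate()` here begins with the input tokens. The sketch below illustrates both steps in isolation. It is only an illustration under stated assumptions: plain Python lists stand in for torch tensors, the token ids are made up, and `split_payload` and `strip_prompt_tokens` are hypothetical helper names that do not appear in the commit.

```python
# Illustrative sketch only. Plain lists stand in for torch tensors, and the
# helper names below are hypothetical (they are not part of handler.py).

def split_payload(data):
    """Mirror the handler's payload parsing: audio under "inputs", optional "prompt"."""
    data = dict(data)  # work on a copy so the caller's dict is not mutated
    inputs_data = data.pop("inputs", data)
    prompt = data.pop("prompt", None)
    return inputs_data, prompt


def strip_prompt_tokens(output_ids, prompt_len):
    """Drop the first prompt_len token ids from every sequence in the batch.

    List-based equivalent of output_ids[:, prompt_len:] on a 2-D tensor.
    """
    return [seq[prompt_len:] for seq in output_ids]


# Made-up token ids: the model output starts with the prompt tokens.
input_ids = [[101, 7, 8]]
output_ids = [[101, 7, 8, 42, 43]]

new_tokens = strip_prompt_tokens(output_ids, len(input_ids[0]))
print(new_tokens)  # [[42, 43]]

audio, prompt = split_payload({"inputs": "sample.wav", "prompt": "Names: Ana"})
print(audio, prompt)  # sample.wav Names: Ana
```

With the slicing in place, only the newly generated ids reach the decode step, so the returned text should not echo the transcription request itself.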
requirements.txt CHANGED
@@ -1,5 +1,6 @@
 torch
-transformers>=5.3.0
 accelerate
 soundfile
 librosa
+sentencepiece
+git+https://github.com/ebezzam/transformers.git@vibevoice_asr