Update custom model files, README, and requirements
Browse files- asr_pipeline.py +5 -0
asr_pipeline.py
CHANGED
|
@@ -460,12 +460,17 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
|
|
| 460 |
return {"tokens": generated_ids, "is_last": is_last}
|
| 461 |
|
| 462 |
def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
|
|
|
|
|
|
|
| 463 |
# Handle list of outputs (from chunking)
|
| 464 |
if isinstance(model_outputs, list):
|
|
|
|
| 465 |
model_outputs = model_outputs[0] if model_outputs else {}
|
| 466 |
|
| 467 |
tokens = model_outputs.get("tokens")
|
|
|
|
| 468 |
if tokens is None:
|
|
|
|
| 469 |
return super().postprocess(model_outputs, **kwargs)
|
| 470 |
|
| 471 |
if torch.is_tensor(tokens):
|
|
|
|
| 460 |
return {"tokens": generated_ids, "is_last": is_last}
|
| 461 |
|
| 462 |
def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
|
| 463 |
+
print(f"[DEBUG postprocess] Called with type: {type(model_outputs)}")
|
| 464 |
+
|
| 465 |
# Handle list of outputs (from chunking)
|
| 466 |
if isinstance(model_outputs, list):
|
| 467 |
+
print(f"[DEBUG postprocess] List with {len(model_outputs)} items")
|
| 468 |
model_outputs = model_outputs[0] if model_outputs else {}
|
| 469 |
|
| 470 |
tokens = model_outputs.get("tokens")
|
| 471 |
+
print(f"[DEBUG postprocess] tokens is None: {tokens is None}")
|
| 472 |
if tokens is None:
|
| 473 |
+
print("[DEBUG postprocess] Falling back to super().postprocess()")
|
| 474 |
return super().postprocess(model_outputs, **kwargs)
|
| 475 |
|
| 476 |
if torch.is_tensor(tokens):
|