from typing import Optional, Union

import torch
import transformers
from transformers import ProcessorMixin

try:
    from .asr_config import ASRConfig
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]


class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models."""

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe: "

    def __init__(self, feature_extractor, tokenizer):
        # Let ProcessorMixin validate and attach the components declared in `attributes`.
        super().__init__(feature_extractor, tokenizer)
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s).
            text: Target transcription (optional, intended for training; prefer the
                DataCollator path instead).
            system_prompt: Optional system prompt.
            return_tensors: Return format ("pt" for PyTorch).

        Returns:
            Dict with input_features, input_ids, and attention_mask.
        """
        result = {}

        # Process audio
        if audio is not None:
            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            # Whisper encoder output length = mel_len // 2 (stride-2 conv)
            num_audio_tokens = audio_inputs["input_features"].shape[-1] // 2
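            # Worked example (assuming the default Whisper feature extractor, which
            # pads/truncates to a 30 s window): 3000 mel frames // 2 = 1500 audio
            # placeholder tokens per clip.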
        else:
            num_audio_tokens = 0

        # Build prompt with audio token placeholders
        user_content = self.TRANSCRIBE_PROMPT
        if num_audio_tokens > 0:
            user_content += self.AUDIO_TOKEN * num_audio_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        # Tokenize; append the generation prompt only when no target text is given,
        # i.e. when the caller expects the model to produce the transcription.
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
        )
        if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        result["input_ids"] = input_ids
        result["attention_mask"] = torch.ones_like(input_ids)
        return result


ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
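

# Usage sketch, not part of the original module: shows how the processor is meant to
# be exercised once a tokenizer carrying the "<audio>" special token is available.
# The checkpoint names below ("openai/whisper-tiny", "HuggingFaceTB/SmolLM2-135M-Instruct")
# are illustrative stand-ins, not necessarily the ones this project uses.
if __name__ == "__main__":
    from transformers import AutoFeatureExtractor, AutoTokenizer

    feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
    tokenizer.add_special_tokens({"additional_special_tokens": [ASRProcessor.AUDIO_TOKEN]})

    processor = ASRProcessor(feature_extractor, tokenizer)
    inputs = processor(audio=[0.0] * 16000)  # one second of silence at 16 kHz
    print({k: tuple(v.shape) for k, v in inputs.items()})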