from typing import Optional, Union

import torch
import transformers
from transformers import ProcessorMixin

try:
    from .asr_config import ASRConfig
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]


class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models.

    Pairs a Whisper feature extractor with a chat-template tokenizer and
    expands the ``<audio>`` placeholder to match the encoder's output length.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe: "

    def __init__(self, feature_extractor, tokenizer):
        # super().__init__ assigns the attributes and wires up
        # save_pretrained / from_pretrained for this processor.
        super().__init__(feature_extractor, tokenizer)
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)

    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s).
            text: Target transcription (optional; intended for training,
                though a data collator is usually preferable there).
            system_prompt: Optional system prompt.
            return_tensors: Return format ("pt" for PyTorch).

        Returns:
            Dict with ``input_features``, ``input_ids``, ``attention_mask``.
        """
        result = {}

        # Process audio
        if audio is not None:
            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            # Whisper encoder output length = mel_len // 2 (stride-2 conv)
            num_audio_tokens = audio_inputs["input_features"].shape[-1] // 2
        else:
            num_audio_tokens = 0

        # Build prompt with audio token placeholders
        user_content = self.TRANSCRIBE_PROMPT
        if num_audio_tokens > 0:
            user_content += self.AUDIO_TOKEN * num_audio_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        # Tokenize
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
        )

        if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        result["input_ids"] = input_ids
        result["attention_mask"] = torch.ones_like(input_ids)

        return result


ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
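

# Example usage (a sketch, not part of the module's API; the checkpoint names
# below are placeholders, and the tokenizer is assumed to already define the
# "<audio>" special token — adjust both to your own model):
#
#   import torch
#   from transformers import AutoFeatureExtractor, AutoTokenizer
#
#   feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-small")
#   tokenizer = AutoTokenizer.from_pretrained("your/llm-checkpoint")  # must know "<audio>"
#   processor = ASRProcessor(feature_extractor, tokenizer)
#
#   waveform = torch.randn(16000)      # one second of 16 kHz audio
#   inputs = processor(audio=waveform)  # keys: input_features, input_ids, attention_mask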