File size: 7,180 Bytes
704d932
 
 
 
 
 
cde5368
704d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b34148
704d932
 
3b34148
704d932
cde5368
 
 
704d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d95e2f8
704d932
 
 
 
 
 
 
cde5368
 
704d932
cde5368
704d932
d95e2f8
704d932
 
 
 
 
 
d95e2f8
 
 
 
704d932
 
 
 
 
 
 
d95e2f8
704d932
 
 
 
 
 
 
cde5368
 
704d932
 
 
 
 
 
cde5368
 
704d932
 
 
cde5368
 
 
 
 
 
 
 
 
 
 
 
704d932
 
 
 
cde5368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d95e2f8
 
 
 
704d932
 
 
 
d95e2f8
704d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
Borealis Processor for HuggingFace/vLLM compatibility.

Handles audio feature extraction and tokenization.
"""

import numpy as np
from typing import List, Optional, Union

import torch
from transformers import ProcessorMixin, BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from transformers import AutoTokenizer


class BorealisProcessor(ProcessorMixin):
    """
    Processor for Borealis audio-language model.

    Combines WhisperFeatureExtractor for audio and Qwen3 tokenizer for text.
    Text prompts use ``<|AUDIO|>`` placeholders; ``__call__`` expands each
    placeholder into a run of audio tokens whose count matches the number of
    audio embeddings the model will produce for that clip.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    # Audio tokens (checkpoint has only 2 special tokens: 151669 and 151670)
    audio_token = "<|AUDIO|>"
    audio_bos_token = "<|start_of_audio|>"
    audio_eos_token = "<|start_of_audio|>"  # Reuse bos token since only 2 audio tokens in vocab

    # Borealis architecture parameters
    downsample_factor = 4  # Audio embedding downsampling factor

    def __init__(
        self,
        feature_extractor: Optional[WhisperFeatureExtractor] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        **kwargs
    ):
        """
        Build the processor, downloading default components if not supplied.

        Args:
            feature_extractor: Whisper mel-spectrogram extractor. Defaults to
                ``openai/whisper-large-v3``.
            tokenizer: Text tokenizer. Defaults to ``Qwen/Qwen3-4B``.
        """
        if feature_extractor is None:
            feature_extractor = WhisperFeatureExtractor.from_pretrained(
                "openai/whisper-large-v3"
            )
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

        super().__init__(feature_extractor, tokenizer)

    @property
    def sampling_rate(self) -> int:
        """Sampling rate (Hz) expected by the feature extractor (16 kHz for Whisper)."""
        return self.feature_extractor.sampling_rate

    @staticmethod
    def _to_float32_array(a) -> np.ndarray:
        """Convert one audio waveform (tensor or array) to a float32 numpy array."""
        if isinstance(a, torch.Tensor):
            # detach + cpu: plain .numpy() raises for CUDA tensors or tensors
            # that require grad.
            a = a.detach().cpu().numpy()
        if isinstance(a, np.ndarray):
            a = a.astype(np.float32)
        return a

    def _num_audio_tokens(self, mel_frames: int) -> int:
        """
        Number of audio placeholder tokens for a clip of `mel_frames` mel frames.

        Whisper's encoder halves the frame count (stride-2 conv):
        3000 mel frames -> 1500 encoder frames; Borealis then downsamples by
        `downsample_factor` (4) -> ~375 tokens.
        """
        whisper_frames = (mel_frames - 1) // 2 + 1  # ceil division by 2
        # Clamp to >= 1 so pathologically short clips still occupy one slot.
        return max(1, whisper_frames // self.downsample_factor)

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        audio: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        audios: Optional[List] = None,  # vLLM uses plural
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and/or audio inputs.

        Expands <|AUDIO|> tokens in text to match the number of audio embeddings.

        Args:
            text: Text prompt(s) containing <|AUDIO|> placeholders
            audio: Audio waveform(s) at 16kHz
            audios: Audio waveform(s) at 16kHz (vLLM style)
            sampling_rate: Audio sampling rate (default: 16000)
            return_tensors: Return tensor type

        Returns:
            BatchFeature with input_ids and optionally input_features
        """
        # vLLM uses 'audios' (plural)
        if audios is not None and audio is None:
            audio = audios

        if sampling_rate is None:
            sampling_rate = self.sampling_rate

        data = {}
        audio_lengths: List[int] = []

        # ---- Audio branch: extract mel features and per-clip frame counts ----
        if audio is not None:
            if not isinstance(audio, list):
                audio = [audio]

            audio_arrays = [self._to_float32_array(a) for a in audio]

            audio_features = self.feature_extractor(
                audio_arrays,
                sampling_rate=sampling_rate,
                return_tensors=return_tensors,
                padding="max_length",
                return_attention_mask=True,
            )
            data["input_features"] = audio_features.input_features

            # Actual (unpadded) length of each clip in mel frames, needed to
            # size the placeholder-token expansion below.
            attention_mask = audio_features.get("attention_mask")
            if attention_mask is None:
                # Default: assume full 30s audio (3000 mel frames)
                audio_lengths = [3000] * len(audio_arrays)
            elif isinstance(attention_mask, torch.Tensor):
                audio_lengths = attention_mask.sum(dim=-1).tolist()
            else:
                # e.g. return_tensors="np": numpy arrays take `axis`, not `dim`
                audio_lengths = np.asarray(attention_mask).sum(axis=-1).tolist()

        # ---- Text branch: expand audio placeholders, then tokenize ----
        if text is not None:
            if isinstance(text, str):
                text = [text]

            if audio is not None:
                expanded_text = []
                audio_idx = 0  # consumes audio_lengths across all samples in order

                for sample in text:
                    # Replace each <|AUDIO|> occurrence (one per clip) with
                    # <bos> + N audio tokens + <eos>.
                    while self.audio_token in sample:
                        if audio_idx >= len(audio_lengths):
                            # More placeholders than clips: leave the rest as-is.
                            break
                        num_audio_tokens = self._num_audio_tokens(
                            int(audio_lengths[audio_idx])
                        )
                        expanded = (
                            self.audio_bos_token +
                            self.audio_token * num_audio_tokens +
                            self.audio_eos_token
                        )
                        sample = sample.replace(self.audio_token, expanded, 1)
                        audio_idx += 1
                    expanded_text.append(sample)
                text = expanded_text

            # Forward only kwargs the tokenizer is known to accept.
            tok_kwargs = {k: v for k, v in kwargs.items()
                         if k in ['padding', 'truncation', 'max_length', 'add_special_tokens']}

            text_inputs = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=True,
                **tok_kwargs,
            )
            data["input_ids"] = text_inputs.input_ids
            if "attention_mask" in text_inputs:
                data["attention_mask"] = text_inputs.attention_mask

        return BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Decode batches of token IDs to text (delegates to the tokenizer)."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode token IDs to text (delegates to the tokenizer)."""
        return self.tokenizer.decode(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load processor from pretrained.

        NOTE(review): the feature extractor is always loaded from
        ``openai/whisper-large-v3`` regardless of the given path — confirm
        this is intentional (a checkpoint-local preprocessor_config is ignored).
        """
        feature_extractor = WhisperFeatureExtractor.from_pretrained(
            "openai/whisper-large-v3"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            **kwargs
        )
        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)

    def save_pretrained(self, save_directory, **kwargs):
        """Save both sub-components (feature extractor + tokenizer) to `save_directory`."""
        self.feature_extractor.save_pretrained(save_directory, **kwargs)
        self.tokenizer.save_pretrained(save_directory, **kwargs)