File size: 4,300 Bytes
704d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Borealis Processor for HuggingFace/vLLM compatibility.

Handles audio feature extraction and tokenization.
"""

from typing import List, Optional, Union

import torch
from transformers import ProcessorMixin, BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from transformers import AutoTokenizer


class BorealisProcessor(ProcessorMixin):
    """
    Processor for Borealis audio-language model.

    Combines WhisperFeatureExtractor for audio and Qwen3 tokenizer for text.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    # Special tokens that mark audio content inside the text stream.
    audio_token = "<|AUDIO|>"
    audio_bos_token = "<|start_of_audio|>"
    audio_eos_token = "<|end_of_audio|>"

    def __init__(
        self,
        feature_extractor: Optional[WhisperFeatureExtractor] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        **kwargs
    ):
        """
        Initialize the processor.

        Args:
            feature_extractor: Audio feature extractor. Defaults to the
                Whisper large-v3 extractor when not provided.
            tokenizer: Text tokenizer. Defaults to the Qwen3-4B tokenizer
                when not provided.
            **kwargs: Accepted for ProcessorMixin compatibility; unused.
        """
        if feature_extractor is None:
            feature_extractor = WhisperFeatureExtractor.from_pretrained(
                "openai/whisper-large-v3"
            )
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

        super().__init__(feature_extractor, tokenizer)

    @property
    def sampling_rate(self) -> int:
        """Sampling rate (Hz) expected by the audio feature extractor."""
        return self.feature_extractor.sampling_rate

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        audio: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and/or audio inputs.

        Args:
            text: Text prompt(s)
            audio: Audio waveform(s) at 16kHz
            sampling_rate: Audio sampling rate (default: the feature
                extractor's rate, typically 16000)
            return_tensors: Return tensor type
            **kwargs: Extra keyword arguments forwarded to the tokenizer.

        Returns:
            BatchFeature with input_ids and optionally input_features
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate

        data = {}

        # Process audio if provided
        if audio is not None:
            if isinstance(audio, torch.Tensor):
                audio = [audio]

            # Convert to numpy for the feature extractor. detach().cpu() is
            # required: .numpy() alone raises for CUDA tensors and for
            # tensors that require grad.
            audio_arrays = []
            for a in audio:
                if isinstance(a, torch.Tensor):
                    a = a.detach().cpu().numpy()
                audio_arrays.append(a)

            audio_features = self.feature_extractor(
                audio_arrays,
                sampling_rate=sampling_rate,
                return_tensors=return_tensors,
            )
            data["input_features"] = audio_features.input_features

        # Process text if provided
        if text is not None:
            if isinstance(text, str):
                text = [text]

            text_inputs = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=True,
                **kwargs,
            )
            data["input_ids"] = text_inputs.input_ids
            if "attention_mask" in text_inputs:
                data["attention_mask"] = text_inputs.attention_mask

        return BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Decode batches of token IDs to text (delegates to the tokenizer)."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode token IDs to text (delegates to the tokenizer)."""
        return self.tokenizer.decode(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load processor from pretrained.

        Prefers the feature extractor saved alongside the checkpoint
        (as written by ``save_pretrained``); falls back to the default
        Whisper large-v3 extractor when the checkpoint does not ship one.
        Previously the saved extractor config was silently ignored.

        Args:
            pretrained_model_name_or_path: Hub model id or local directory.
            **kwargs: Forwarded to ``AutoTokenizer.from_pretrained``.

        Returns:
            A BorealisProcessor instance.
        """
        try:
            feature_extractor = WhisperFeatureExtractor.from_pretrained(
                pretrained_model_name_or_path
            )
        except (OSError, ValueError):
            # Checkpoint has no preprocessor config; use the default.
            feature_extractor = WhisperFeatureExtractor.from_pretrained(
                "openai/whisper-large-v3"
            )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            **kwargs
        )
        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)

    def save_pretrained(self, save_directory, **kwargs):
        """Save both sub-components into ``save_directory``."""
        self.feature_extractor.save_pretrained(save_directory, **kwargs)
        self.tokenizer.save_pretrained(save_directory, **kwargs)