AlexWortega committed on
Commit 83a9bad · verified · Parent: 08a6dff

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +174 -0
  2. config.json +16 -0
  3. model.safetensors +3 -0
  4. modeling_borealis.py +436 -0
  5. push_model.py +63 -0
README.md ADDED
@@ -0,0 +1,174 @@
+ ---
+ license: apache-2.0
+ language:
+ - ru
+ - en
+ pipeline_tag: audio-text-to-text
+ tags:
+ - audio
+ - speech
+ - multimodal
+ - whisper
+ - qwen
+ library_name: transformers
+ ---
+
+ # Borealis-5B-IT
+
+ Borealis is an audio-language model that combines a Whisper encoder with a Qwen3-4B LLM for speech understanding and instruction-following tasks.
+
+ ## Model Description
+
+ - **Audio Encoder**: Whisper Large V3 (frozen)
+ - **Language Model**: Qwen3-4B (fine-tuned)
+ - **Adapter**: 2-layer MLP projecting audio embeddings to LLM space
+ - **Total Parameters**: ~5B
+ - **Languages**: Russian, English
+
+ ## Installation
+
+ ```bash
+ pip install transformers torch torchaudio safetensors
+ ```
+
+ ## Quick Start
+
+ ```python
+ import torch
+ import torchaudio
+ from transformers import AutoModel
+
+ # Load model
+ model = AutoModel.from_pretrained(
+     "Vikhrmodels/Borealis-5b-it",
+     trust_remote_code=True,
+     device="cuda"
+ )
+ model.eval()
+
+ # Load audio
+ audio, sr = torchaudio.load("your_audio.wav")
+ if sr != 16000:
+     audio = torchaudio.functional.resample(audio, sr, 16000)
+ audio = audio.squeeze()
+
+ # Generate response
+ with torch.inference_mode():
+     output_ids = model.generate(
+         audio=audio,
+         user_prompt="What is being said in this audio? <|start_of_audio|><|end_of_audio|>",
+         system_prompt="You are a helpful voice assistant.",
+         max_new_tokens=256,
+         temperature=0.7,
+     )
+
+ response = model.decode(output_ids[0])
+ print(response)
+ ```
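+
+ The call above also fetches the Whisper encoder and Qwen3-4B weights as separate components. A minimal sketch, assuming the `load_components` flag defined in `modeling_borealis.py`, for deferring that step:
+
+ ```python
+ model = AutoModel.from_pretrained(
+     "Vikhrmodels/Borealis-5b-it",
+     trust_remote_code=True,
+     load_components=False,  # skip fetching Whisper/Qwen3 for now
+     device="cpu",
+ )
+ model.load_components(device="cuda")  # loads encoder, LLM, tokenizer, feature extractor
+ ```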
+
+ ## Prompt Examples
+
+ ### Audio Transcription
+ ```python
+ output = model.generate(
+     audio=audio,
+     user_prompt="Transcribe this audio: <|start_of_audio|><|end_of_audio|>",
+     system_prompt="You are a speech recognition assistant. Accurately transcribe audio to text."
+ )
+ ```
+
+ ### Audio Summarization
+ ```python
+ output = model.generate(
+     audio=audio,
+     user_prompt="Summarize what is said in this recording: <|start_of_audio|><|end_of_audio|>",
+     system_prompt="You are a helpful voice assistant."
+ )
+ ```
+
+ ### Audio Q&A (Russian)
+ ```python
+ output = model.generate(
+     audio=audio,
+     user_prompt="О чём говорится в этой аудиозаписи? <|start_of_audio|><|end_of_audio|>",
+     system_prompt="Ты полезный голосовой ассистент."
+ )
+ ```
+
+ ### Content Description
+ ```python
+ output = model.generate(
+     audio=audio,
+     user_prompt="Describe in detail what you hear: <|start_of_audio|><|end_of_audio|>",
+     system_prompt="You are an attentive listener."
+ )
+ ```
+
+ ### Emotion Analysis
+ ```python
+ output = model.generate(
+     audio=audio,
+     user_prompt="What emotions does the speaker express? <|start_of_audio|><|end_of_audio|>",
+     system_prompt="You are an expert in audio analysis."
+ )
+ ```
+
+ ## Training Data
+
+ The model was fine-tuned on a diverse mix of audio-instruction datasets:
+
+ | Dataset | Description | Size |
+ |---------|-------------|------|
+ | [Vikhrmodels/Speech-Instructions](https://huggingface.co/datasets/Vikhrmodels/Speech-Instructions) | General speech instruction-following | 70k |
+ | [Vikhrmodels/Speech-Describe](https://huggingface.co/datasets/Vikhrmodels/Speech-Describe) | Audio description tasks (speech & non-speech) | ~2M |
+ | [Vikhrmodels/ToneBooks](https://huggingface.co/datasets/Vikhrmodels/ToneBooks) | Russian audiobook excerpts | - |
+ | [Vikhrmodels/AudioBooksInstructGemini2.5](https://huggingface.co/datasets/Vikhrmodels/AudioBooksInstructGemini2.5) | Instruction data generated with Gemini 2.5 | - |
+
+ ## Model Architecture
+
+ ```
+ Audio Input (16kHz)
+         │
+         ▼
+ ┌─────────────────┐
+ │ Whisper Large V3│ (Frozen)
+ │     Encoder     │
+ └────────┬────────┘
+          │ (1280-dim embeddings)
+          ▼
+ ┌─────────────────┐
+ │   Downsampler   │ (4x temporal reduction)
+ │   + Adapter     │
+ └────────┬────────┘
+          │ (2560-dim embeddings)
+          ▼
+ ┌─────────────────┐
+ │    Qwen3-4B     │ (Fine-tuned)
+ │       LLM       │
+ └────────┬────────┘
+          │
+          ▼
+    Text Output
+ ```
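+
+ As a rough shape walkthrough (an illustrative sketch; the 1500-frame figure assumes Whisper's standard 30-second encoder window):
+
+ ```python
+ # 30 s of 16 kHz audio -> Whisper large-v3 encoder -> (1500, 1280)
+ T, d, k = 1500, 1280, 4        # encoder frames, encoder dim, downsample factor
+ T_ds, d_ds = T // k, d * k     # -> (375, 5120) after frame concatenation
+ llm_dim = 2560                 # Qwen3-4B hidden size (see config.json)
+ # adapter: Linear(5120 -> 2560) -> GELU -> Linear(2560 -> 2560),
+ # so one 30 s clip occupies ~375 audio positions in the LLM context
+ ```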
+
+ ## Limitations
+
+ - Optimized for audio up to 30 seconds (longer input can be split into 30-second chunks; see the sketch below)
+ - Best performance on Russian and English
+ - May not handle heavily noisy audio well
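+
+ For longer recordings, an untested sketch that feeds several 30-second mel chunks as a single sample, assuming the `List[List[Tensor]]` format that `generate(mel=...)` accepts in `modeling_borealis.py`:
+
+ ```python
+ chunk = 30 * 16000  # 30 s at 16 kHz
+ pieces = [audio[i : i + chunk] for i in range(0, audio.numel(), chunk)]
+ mels = [model.prepare_audio(p)[0][0] for p in pieces]  # one mel per 30 s piece
+ with torch.inference_mode():
+     output_ids = model.generate(
+         mel=[mels],  # one batch element made of several chunks
+         user_prompt="Transcribe this audio: <|start_of_audio|><|end_of_audio|>",
+     )
+ print(model.decode(output_ids[0]))
+ ```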
+
+ ## Citation
+
+ ```bibtex
+ @misc{borealis2025,
+   title={Borealis: Audio-Language Model for Speech Understanding},
+   author={VikhrModels},
+   year={2025},
+   publisher={HuggingFace},
+   url={https://huggingface.co/Vikhrmodels/Borealis-5b-it}
+ }
+ ```
+
+ ## License
+
+ Apache 2.0
config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "architectures": ["BorealisForConditionalGeneration"],
+   "model_type": "borealis",
+   "whisper_model_name": "openai/whisper-large-v3",
+   "llm_model_name": "Qwen/Qwen3-4B",
+   "downsample_factor": 4,
+   "audio_hidden_size": 1280,
+   "llm_hidden_size": 2560,
+   "torch_dtype": "bfloat16",
+   "auto_map": {
+     "AutoConfig": "modeling_borealis.BorealisConfig",
+     "AutoModel": "modeling_borealis.BorealisForConditionalGeneration",
+     "AutoModelForCausalLM": "modeling_borealis.BorealisForConditionalGeneration"
+   },
+   "transformers_version": "4.48.0"
+ }
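Note: the `auto_map` block is what lets `AutoModel.from_pretrained(..., trust_remote_code=True)` resolve the custom `BorealisConfig` and `BorealisForConditionalGeneration` classes from `modeling_borealis.py`.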
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a4de2e0360cdf08396a69adb3c7f78c3db5a27998ce7effbe02f57676649a82b
+ size 10133496400
modeling_borealis.py ADDED
@@ -0,0 +1,436 @@
+ """
+ Borealis: Audio-Language Model for Speech Understanding
+
+ This model combines a Whisper encoder with a Qwen3 LLM for audio understanding tasks.
+ """
+
+ import math
+ from typing import Optional, List, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import (
+     PreTrainedModel,
+     PretrainedConfig,
+     WhisperModel,
+     WhisperFeatureExtractor,
+     Qwen3ForCausalLM,
+     AutoTokenizer,
+ )
+
+
+ class BorealisConfig(PretrainedConfig):
+     """Configuration class for Borealis model."""
+
+     model_type = "borealis"
+
+     def __init__(
+         self,
+         whisper_model_name: str = "openai/whisper-large-v3",
+         llm_model_name: str = "Qwen/Qwen3-4B",
+         downsample_factor: int = 4,
+         audio_hidden_size: int = 1280,
+         llm_hidden_size: int = 2560,
+         torch_dtype: str = "bfloat16",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.whisper_model_name = whisper_model_name
+         self.llm_model_name = llm_model_name
+         self.downsample_factor = downsample_factor
+         self.audio_hidden_size = audio_hidden_size
+         self.llm_hidden_size = llm_hidden_size
+         self.torch_dtype = torch_dtype
+
+
+ class AudioLanguageAdapter(nn.Module):
+     """Adapter module that projects audio embeddings to LLM embedding space."""
+
+     def __init__(self, hidden_size: int, dim: int) -> None:
+         super().__init__()
+         self.w_in = nn.Linear(hidden_size, dim, bias=False)
+         self.gelu = nn.GELU()
+         self.w_out = nn.Linear(dim, dim, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.w_out(self.gelu(self.w_in(x)))
+
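+     # Dimension note (added comment): with the default config, hidden_size is
+     # audio_hidden_size * downsample_factor = 1280 * 4 = 5120 and dim is
+     # llm_hidden_size = 2560, so this is a 5120 -> 2560 -> 2560 MLP.
+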
+
+ class BorealisForConditionalGeneration(PreTrainedModel):
+     """
+     Borealis model for audio-to-text generation.
+
+     Combines a Whisper encoder for audio processing with a Qwen3 LLM for text generation.
+     Supports instruction-following tasks on audio input.
+     """
+
+     config_class = BorealisConfig
+     base_model_prefix = "borealis"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["AudioLanguageAdapter"]
+
+     def __init__(self, config: BorealisConfig):
+         super().__init__(config)
+         self.config = config
+
+         # These will be loaded in from_pretrained or set manually
+         self.encoder = None
+         self.llm = None
+         self.tokenizer = None
+         self.feature_extractor = None
+
+         self.downsample_factor = config.downsample_factor
+
+         # Initialize adapter
+         self.adapter = AudioLanguageAdapter(
+             hidden_size=config.audio_hidden_size * config.downsample_factor,
+             dim=config.llm_hidden_size,
+         )
+
+         # Special token IDs (will be set after the tokenizer is loaded)
+         self.audio_start_id = None
+         self.audio_end_id = None
+         self.im_start_id = None
+         self.im_end_id = None
+
+     def _setup_special_tokens(self):
+         """Set up special token IDs after the tokenizer is loaded."""
+         if self.tokenizer is not None:
+             self.audio_start_id = self.tokenizer.convert_tokens_to_ids("<|start_of_audio|>")
+             self.audio_end_id = self.tokenizer.convert_tokens_to_ids("<|end_of_audio|>")
+             self.im_start_id = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
+             self.im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+     def _downsample(self, seq: torch.Tensor) -> torch.Tensor:
+         """Downsample an audio sequence by concatenating adjacent frames."""
+         k, (T, d) = self.downsample_factor, seq.shape
+         target = k * math.ceil(T / k)
+         if target != T:
+             seq = F.pad(seq, (0, 0, 0, target - T))
+         return seq.contiguous().view(target // k, d * k)
+
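+     # Worked example (added comment): with k=4, a (1500, 1280) encoder
+     # sequence needs no padding and reshapes to (375, 5120), i.e. four
+     # adjacent frames are concatenated into one downsampled step.
+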
+     def _process_audio(self, mel) -> tuple:
+         """Process mel spectrograms through encoder and adapter."""
+         B, device = len(mel), mel[0][0].device
+         audio_embs = []
+         audio_mask = []
+         per_sample_T = []
+         max_T = 0
+
+         for b in range(B):
+             chunk_stack = torch.stack(mel[b])
+             enc_chunks = self.encoder(
+                 input_features=chunk_stack, return_dict=True
+             ).last_hidden_state
+             enc_long = enc_chunks.view(-1, enc_chunks.size(-1))
+             ds_long = self._downsample(enc_long)
+             audio_embs.append(ds_long)
+             per_sample_T.append(ds_long.size(0))
+             max_T = max(max_T, ds_long.size(0))
+
+         for i in range(B):
+             pad = max_T - per_sample_T[i]
+             if pad > 0:
+                 audio_embs[i] = F.pad(audio_embs[i], (0, 0, 0, pad))
+                 audio_mask.append(
+                     torch.ones(per_sample_T[i], dtype=torch.long, device=device)
+                 )
+                 audio_mask[i] = F.pad(audio_mask[i], (0, pad), value=0)
+             else:
+                 audio_mask.append(
+                     torch.ones(per_sample_T[i], dtype=torch.long, device=device)
+                 )
+
+         audio_embeddings = torch.stack(audio_embs)
+         audio_mask = torch.stack(audio_mask)
+         audio_embeddings = self.adapter(audio_embeddings)
+
+         return audio_embeddings, audio_mask, per_sample_T
+
+     def prepare_audio(
+         self,
+         audio: Union[torch.Tensor, List[torch.Tensor]],
+         sampling_rate: int = 16000,
+     ) -> List[List[torch.Tensor]]:
+         """
+         Prepare raw audio waveforms for the model.
+
+         Args:
+             audio: Audio waveform(s) as tensor(s). Can be:
+                 - Single tensor of shape (samples,)
+                 - List of tensors
+             sampling_rate: Audio sampling rate (default: 16000)
+
+         Returns:
+             List of mel spectrogram chunks ready for the model
+         """
+         if self.feature_extractor is None:
+             raise ValueError("Feature extractor not loaded. Call load_components() first.")
+
+         if isinstance(audio, torch.Tensor) and audio.dim() == 1:
+             audio = [audio]
+
+         device = next(self.parameters()).device
+         mel_chunks = []
+
+         for audio_sample in audio:
+             if isinstance(audio_sample, torch.Tensor):
+                 audio_np = audio_sample.cpu().numpy()
+             else:
+                 audio_np = audio_sample
+
+             mel = self.feature_extractor(
+                 audio_np,
+                 sampling_rate=sampling_rate,
+                 return_tensors="pt",
+                 padding="max_length",
+                 max_length=30 * sampling_rate,
+                 truncation=True,
+             ).input_features.to(device).to(self.dtype)
+
+             mel_chunks.append([mel.squeeze(0)])
+
+         return mel_chunks
+
+     def load_components(self, device: str = "cuda"):
+         """
+         Load the Whisper encoder, LLM, tokenizer, and feature extractor.
+
+         Args:
+             device: Device to load models on
+         """
+         dtype = getattr(torch, self.config.torch_dtype)
+
+         # Load Whisper encoder (kept frozen)
+         whisper = WhisperModel.from_pretrained(
+             self.config.whisper_model_name,
+             torch_dtype=dtype,
+         )
+         self.encoder = whisper.encoder.to(device)
+         self.encoder.eval()
+         for p in self.encoder.parameters():
+             p.requires_grad = False
+
+         # Load feature extractor
+         self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
+             self.config.whisper_model_name
+         )
+
+         # Load LLM
+         self.llm = Qwen3ForCausalLM.from_pretrained(
+             self.config.llm_model_name,
+             torch_dtype=dtype,
+             attn_implementation="sdpa",
+         ).to(device)
+
+         # Load tokenizer and register the audio boundary tokens
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.config.llm_model_name,
+             trust_remote_code=True,
+         )
+         self.tokenizer.add_special_tokens({
+             "additional_special_tokens": ["<|start_of_audio|>", "<|end_of_audio|>"]
+         })
+         self.llm.resize_token_embeddings(len(self.tokenizer))
+
+         # Set up special tokens
+         self._setup_special_tokens()
+
+         # Move adapter to device
+         self.adapter = self.adapter.to(device).to(dtype)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: str,
+         *model_args,
+         device: str = "cuda",
+         load_components: bool = True,
+         **kwargs,
+     ):
+         """
+         Load a pretrained Borealis model.
+
+         Args:
+             pretrained_model_name_or_path: Path or HuggingFace model ID
+             device: Device to load on
+             load_components: Whether to automatically load Whisper/LLM components
+             **kwargs: Additional arguments passed to PreTrainedModel.from_pretrained
+
+         Returns:
+             BorealisForConditionalGeneration model
+         """
+         config = kwargs.pop("config", None)
+         if config is None:
+             config = BorealisConfig.from_pretrained(pretrained_model_name_or_path)
+
+         model = cls(config)
+
+         # Load adapter weights from the checkpoint
+         import os
+         from safetensors.torch import load_file
+
+         if os.path.isdir(pretrained_model_name_or_path):
+             weights_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+             if not os.path.exists(weights_path):
+                 weights_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+         else:
+             from huggingface_hub import hf_hub_download
+             try:
+                 weights_path = hf_hub_download(
+                     repo_id=pretrained_model_name_or_path,
+                     filename="model.safetensors",
+                 )
+             except Exception:
+                 weights_path = hf_hub_download(
+                     repo_id=pretrained_model_name_or_path,
+                     filename="pytorch_model.bin",
+                 )
+
+         if weights_path.endswith(".safetensors"):
+             state_dict = load_file(weights_path)
+         else:
+             state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
+
+         # Load adapter weights
+         adapter_state = {
+             k.replace("adapter.", ""): v
+             for k, v in state_dict.items()
+             if k.startswith("adapter.")
+         }
+         model.adapter.load_state_dict(adapter_state)
+
+         if load_components:
+             model.load_components(device=device)
+
+             # Load encoder weights if present in the checkpoint
+             encoder_state = {
+                 k.replace("encoder.", ""): v
+                 for k, v in state_dict.items()
+                 if k.startswith("encoder.")
+             }
+             if encoder_state:
+                 model.encoder.load_state_dict(encoder_state, strict=False)
+
+             # Load LLM weights if present
+             llm_state = {
+                 k.replace("llm.", ""): v
+                 for k, v in state_dict.items()
+                 if k.startswith("llm.")
+             }
+             if llm_state:
+                 model.llm.load_state_dict(llm_state, strict=False)
+
+         return model.to(device)
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         audio: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+         mel: Optional[List[List[torch.Tensor]]] = None,
+         system_prompt: Optional[str] = None,
+         user_prompt: Optional[str] = None,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.9,
+         do_sample: bool = True,
+         **kwargs,
+     ):
+         """
+         Generate a text response for audio input.
+
+         Args:
+             audio: Raw audio waveform(s). Either audio or mel must be provided.
+             mel: Pre-processed mel spectrograms. Either audio or mel must be provided.
+             system_prompt: System prompt for the model
+             user_prompt: User prompt (should contain <|start_of_audio|><|end_of_audio|> tags)
+             max_new_tokens: Maximum tokens to generate
+             temperature: Sampling temperature
+             top_p: Top-p sampling parameter
+             do_sample: Whether to use sampling
+             **kwargs: Additional generation arguments
+
+         Returns:
+             Generated token IDs
+         """
+         if audio is not None:
+             mel = self.prepare_audio(audio)
+         elif mel is not None:
+             if not isinstance(mel, list) or len(mel) == 0 or not isinstance(mel[0], list):
+                 mel = [mel]
+             mel = [[c.to(self.dtype) for c in m] for m in mel]
+         else:
+             raise ValueError("Either audio or mel must be provided")
+
+         B, device = len(mel), mel[0][0].device
+
+         audio_embeddings, audio_mask, per_sample_T = self._process_audio(mel)
+
+         if system_prompt is None:
+             system_prompt = "You are a helpful voice assistant. Listen to the audio and respond appropriately."
+         if user_prompt is None:
+             user_prompt = "What is being said in this audio? <|start_of_audio|><|end_of_audio|>"
+         elif "<|start_of_audio|>" not in user_prompt:
+             user_prompt = f"{user_prompt}\n<|start_of_audio|><|end_of_audio|>"
+
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt},
+         ]
+
+         chat_text = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True,
+         )
+
+         model_inputs = self.tokenizer(chat_text, return_tensors="pt").to(device)
+
+         input_ids = model_inputs.input_ids.repeat(B, 1)
+         text_att_mask = model_inputs.attention_mask.repeat(B, 1)
+
+         text_embeddings = self.llm.get_input_embeddings()(input_ids)
+
+         # All samples share one prompt, so the audio tag positions are located once
+         sa_idx = (input_ids[0] == self.audio_start_id).nonzero(as_tuple=True)[0].item()
+         ea_idx = (input_ids[0] == self.audio_end_id).nonzero(as_tuple=True)[0].item()
+
+         inputs_embeds = []
+         full_att_mask = []
+
+         for b in range(B):
+             # Splice the audio embeddings between the start/end audio tags
+             prefix_emb = text_embeddings[b, : sa_idx + 1]
+             postfix_emb = text_embeddings[b, ea_idx:]
+             emb = torch.cat([prefix_emb, audio_embeddings[b], postfix_emb], dim=0)
+
+             prefix_mask = text_att_mask[b, : sa_idx + 1]
+             postfix_mask = text_att_mask[b, ea_idx:]
+             mask = torch.cat([prefix_mask, audio_mask[b], postfix_mask], dim=0)
+
+             inputs_embeds.append(emb)
+             full_att_mask.append(mask)
+
+         inputs_embeds = torch.nn.utils.rnn.pad_sequence(
+             inputs_embeds, batch_first=True, padding_value=0.0
+         )
+
+         att_mask = torch.nn.utils.rnn.pad_sequence(
+             full_att_mask, batch_first=True, padding_value=0
+         )
+
+         gen_ids = self.llm.generate(
+             inputs_embeds=inputs_embeds,
+             attention_mask=att_mask,
+             max_new_tokens=max_new_tokens,
+             eos_token_id=self.tokenizer.eos_token_id,
+             temperature=temperature,
+             top_p=top_p,
+             do_sample=do_sample,
+             **kwargs,
+         )
+
+         return gen_ids
+
+     def decode(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
+         """Decode token IDs to text."""
+         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
push_model.py ADDED
@@ -0,0 +1,63 @@
+ """Push the Borealis model to the HuggingFace Hub."""
+
+ import os
+ import torch
+ from huggingface_hub import HfApi, create_repo
+ from safetensors.torch import save_file
+
+ # Config
+ HF_REPO = "Vikhrmodels/Borealis-5b-it"
+ CHECKPOINT_PATH = "/home/alex/Borealis/borealis_instruct_ckpts/checkpoint-2898/pytorch_model.bin"
+ OUTPUT_DIR = "/home/alex/Borealis/hf_upload"
+
+
+ class DictModule(torch.nn.Module):
+     """Wrapper to use save_model with a plain state_dict."""
+
+     def __init__(self, state_dict):
+         super().__init__()
+         for k, v in state_dict.items():
+             # Replace dots with a placeholder to form valid attribute names
+             self.register_buffer(k.replace(".", "__DOT__"), v)
+
+     def state_dict(self, *args, **kwargs):
+         sd = super().state_dict(*args, **kwargs)
+         return {k.replace("__DOT__", "."): v for k, v in sd.items()}
+
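+ # Note (added comment): DictModule is not referenced by main() below, which
+ # saves with save_file directly; it appears to be a retained alternative for
+ # safetensors' save_model API.
+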
+ def main():
+     print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
+     state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)
+     print(f"Loaded {len(state_dict)} keys")
+
+     # Handle shared tensors by cloning each one
+     print("Handling shared tensors...")
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         new_state_dict[k] = v.clone()
+
+     # Convert to safetensors format with save_file
+     print("Converting to safetensors format...")
+     os.makedirs(OUTPUT_DIR, exist_ok=True)
+     safetensors_path = os.path.join(OUTPUT_DIR, "model.safetensors")
+     save_file(new_state_dict, safetensors_path)
+     print(f"Saved to {safetensors_path}")
+
+     # Create the repo (no-op if it already exists)
+     print(f"\nCreating/accessing repo: {HF_REPO}")
+     api = HfApi()
+     try:
+         create_repo(HF_REPO, repo_type="model", exist_ok=True)
+     except Exception as e:
+         print(f"Repo note: {e}")
+
+     # Upload the folder
+     print(f"\nUploading to {HF_REPO}...")
+     api.upload_folder(
+         folder_path=OUTPUT_DIR,
+         repo_id=HF_REPO,
+         repo_type="model",
+     )
+
+     print(f"\nDone! Model available at: https://huggingface.co/{HF_REPO}")
+
+
+ if __name__ == "__main__":
+     main()