import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    WhisperFeatureExtractor,
    WhisperModel,
)

from .configuration_borealis import BorealisConfig

class AudioLanguageAdapter(nn.Module):
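    """Two-layer projection (Linear -> GELU -> Linear) that maps stacked
    Whisper encoder frames into the language model's embedding space."""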

    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))


class BorealisForConditionalGeneration(PreTrainedModel, PyTorchModelHubMixin):
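    """Speech-to-text model: a frozen Whisper encoder produces audio features
    that are downsampled, projected by an adapter, and spliced between the
    <|start_of_audio|> / <|end_of_audio|> tokens of a causal LLM prompt."""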

    config_class = BorealisConfig

    def __init__(self, config: BorealisConfig, language_model=None, tokenizer=None):
        super().__init__(config)
        assert tokenizer is not None, "A tokenizer must be passed to the model"
        # The Whisper encoder is frozen: bfloat16, eval mode, no gradients.
        self.encoder: WhisperModel = WhisperModel.from_pretrained(
            config.whisper_encoder_name
        ).encoder
        self.encoder.to(torch.bfloat16)
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.llm = language_model
        self.tokenizer = tokenizer
        self.llm.resize_token_embeddings(len(tokenizer))
        self.downsample_factor = config.downsample_factor
        # The adapter consumes `downsample_factor` concatenated encoder frames
        # per step, hence the widened input dimension.
        self.adapter = AudioLanguageAdapter(
            hidden_size=self.encoder.config.d_model * self.downsample_factor,
            dim=self.llm.config.hidden_size,
        )
        self.adapter.to(torch.bfloat16)
        self.bos_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.audio_start_id = tokenizer.convert_tokens_to_ids("<|start_of_audio|>")
        self.audio_end_id = tokenizer.convert_tokens_to_ids("<|end_of_audio|>")

    def _downsample(self, seq: torch.Tensor) -> torch.Tensor:
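        """Pad `seq` of shape (T, d) up to a multiple of the downsample factor
        k, then fold every k consecutive frames into one vector, returning
        shape (ceil(T / k), d * k)."""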
        k, (T, d) = self.downsample_factor, seq.shape
        target = k * math.ceil(T / k)
        if target != T:
            seq = F.pad(seq, (0, 0, 0, target - T))
        return seq.contiguous().view(target // k, d * k)

    def _tok_embed(self, tok_id: int, batch: int, device) -> torch.Tensor:
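        """Return the embedding of a single token id, tiled to (batch, 1, hidden)."""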
        idx = torch.full((batch, 1), tok_id, dtype=torch.long, device=device)
        return self.llm.get_input_embeddings()(idx)

    def forward(
        self,
        mel: torch.Tensor,
        audio_att_mask: torch.Tensor,
        labels: torch.Tensor,
        text_att_mask: torch.Tensor,
    ):
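        """Training forward pass: encode `mel` with the frozen Whisper encoder,
        downsample and project the frames, splice them between the audio marker
        tokens in `labels`, and compute the LM loss on the assistant response
        only (all other positions are masked to -100). `audio_att_mask` is
        accepted for interface compatibility but the encoder currently runs
        without an attention mask."""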
        B, device = mel.size(0), mel.device
        enc_out = self.encoder(
            input_features=mel, attention_mask=None, return_dict=True
        ).last_hidden_state
        audio_embs, audio_mask, max_T = [], [], 0
        for seq in enc_out:
            ds = self._downsample(seq)
            audio_embs.append(ds)
            max_T = max(max_T, ds.size(0))
        for i, ds in enumerate(audio_embs):
            pad = max_T - ds.size(0)
            audio_mask.append(
                torch.cat(
                    [
                        torch.ones(ds.size(0), dtype=torch.long, device=device),
                        torch.zeros(pad, dtype=torch.long, device=device),
                    ]
                )
            )
            if pad:
                # Write the padded tensor back into the list; rebinding the
                # loop variable alone discards the padding and torch.stack
                # below fails on ragged lengths (generate() already does this).
                audio_embs[i] = F.pad(ds, (0, 0, 0, pad))
        audio_embeddings = torch.stack(audio_embs, 0)
        audio_mask = torch.stack(audio_mask, 0)
        audio_embeddings = self.adapter(audio_embeddings)
        text_embeddings = self.llm.get_input_embeddings()(labels)
        sa_positions = (labels == self.audio_start_id).nonzero(as_tuple=True)
        ea_positions = (labels == self.audio_end_id).nonzero(as_tuple=True)
        inputs_embeds = []
        att_mask = []
        gaps = []  # number of tokens between the audio markers, per sample
        for b in range(B):
            sa_idx = sa_positions[1][sa_positions[0] == b].item()
            ea_idx = ea_positions[1][ea_positions[0] == b].item()
            gaps.append(ea_idx - sa_idx - 1)
            prefix_emb = text_embeddings[b, : sa_idx + 1]
            postfix_emb = text_embeddings[b, ea_idx:]
            emb = torch.cat([prefix_emb, audio_embeddings[b], postfix_emb], dim=0)
            prefix_mask = text_att_mask[b, : sa_idx + 1]
            postfix_mask = text_att_mask[b, ea_idx:]
            full_mask = torch.cat([prefix_mask, audio_mask[b], postfix_mask], dim=0)
            inputs_embeds.append(emb)
            att_mask.append(full_mask)
        inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            inputs_embeds, batch_first=True, padding_value=0.0
        )
        att_mask = torch.nn.utils.rnn.pad_sequence(
            att_mask, batch_first=True, padding_value=0
        )
        assistant_prompt = torch.tensor(
            self.tokenizer(
                "<|im_start|>assistant\n", add_special_tokens=False
            ).input_ids,
            device=device,
        )
        assistant_starts = []
        for b in range(B):
            seq = labels[b]
            # Scan all windows, including one ending at the sequence end.
            for i in range(len(seq) - len(assistant_prompt) + 1):
                if torch.equal(seq[i : i + len(assistant_prompt)], assistant_prompt):
                    assistant_start = i + len(assistant_prompt)
                    break
            else:
                raise ValueError("Assistant prompt not found")
            # Splicing removed gaps[b] tokens between the markers and inserted
            # max_T audio frames, so later positions shift by max_T - gaps[b].
            # (The original used the leaked loop indices from the splice loop;
            # per-sample gaps are recorded there instead.)
            assistant_starts.append(assistant_start + max_T - gaps[b])
        max_len = inputs_embeds.size(1)
        loss_labels = labels.new_full((B, max_len), -100)
        for b in range(B):
            # Invert the position shift above to locate the assistant span
            # in the original `labels`.
            orig_assist_start = assistant_starts[b] - max_T + gaps[b]
            content_len = len(labels[b]) - orig_assist_start
            loss_labels[b, assistant_starts[b] : assistant_starts[b] + content_len] = (
                labels[b, orig_assist_start:]
            )
        if self.tokenizer.pad_token_id is not None:
            loss_labels[loss_labels == self.tokenizer.pad_token_id] = -100
        out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=att_mask,
            labels=loss_labels,
            return_dict=True,
        )
        return out.loss, out.logits

    @torch.no_grad()
    def generate(
        self,
        mel: torch.Tensor,
        att_mask: torch.Tensor,
        max_new_tokens: int = 512,
        **kwargs,
    ):
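        """Transcribe `mel` (a single (n_mels, T) example or a batch) and return
        decoded text, or raw token ids when `return_tokens=True` is passed via
        kwargs. `att_mask` is accepted for interface symmetry, but the encoder
        is run without an attention mask."""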
        return_tokens = kwargs.pop("return_tokens", False)
        single = mel.dim() == 2
        if single:
            mel, att_mask = mel.unsqueeze(0), att_mask.unsqueeze(0)
        mel = mel.to(torch.bfloat16)
        B, device = mel.size(0), mel.device
        enc_out = self.encoder(
            input_features=mel, attention_mask=None, return_dict=True
        ).last_hidden_state
        audio_embs, audio_mask, max_T = [], [], 0
        for seq in enc_out:
            ds = self._downsample(seq)
            audio_embs.append(ds)
            max_T = max(max_T, ds.size(0))
        for i, ds in enumerate(audio_embs):
            pad = max_T - ds.size(0)
            audio_mask.append(
                torch.cat(
                    [
                        torch.ones(ds.size(0), dtype=torch.long, device=device),
                        torch.zeros(pad, dtype=torch.long, device=device),
                    ]
                )
            )
            if pad:
                audio_embs[i] = F.pad(ds, (0, 0, 0, pad))
        audio_embeddings = torch.stack(audio_embs, 0)
        audio_mask = torch.stack(audio_mask, 0)
        audio_embeddings = self.adapter(audio_embeddings)
        # The prompt stays in Russian to match the model's training template.
        # In English: system = "You are a helpful assistant for automatic
        # speech recognition. Transcribe the audio to text accurately.",
        # user = "Transcribe this audio: ...".
        messages = [
            {
                "role": "system",
                "content": "Вы полезный помощник по автоматическому распознаванию речи. Точно транскрибируйте аудио в текст.",
            },
            {
                "role": "user",
                "content": "Транскрибируйте это аудио: <|start_of_audio|><|end_of_audio|>",
            },
        ]
        chat_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self.tokenizer(chat_text, return_tensors="pt").to(device)
        input_ids = model_inputs.input_ids.repeat(B, 1)
        text_att_mask = model_inputs.attention_mask.repeat(B, 1)
        text_embeddings = self.llm.get_input_embeddings()(input_ids)
        # The prompt is identical across the batch, so marker positions are shared.
        sa_idx = (input_ids[0] == self.audio_start_id).nonzero(as_tuple=True)[0].item()
        ea_idx = (input_ids[0] == self.audio_end_id).nonzero(as_tuple=True)[0].item()
        inputs_embeds = []
        full_att_mask = []
        for b in range(B):
            prefix_emb = text_embeddings[b, : sa_idx + 1]
            postfix_emb = text_embeddings[b, ea_idx:]
            emb = torch.cat([prefix_emb, audio_embeddings[b], postfix_emb], dim=0)
            prefix_mask = text_att_mask[b, : sa_idx + 1]
            postfix_mask = text_att_mask[b, ea_idx:]
            mask = torch.cat([prefix_mask, audio_mask[b], postfix_mask], dim=0)
            inputs_embeds.append(emb)
            full_att_mask.append(mask)
        inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            inputs_embeds, batch_first=True, padding_value=0.0
        )
        att_mask = torch.nn.utils.rnn.pad_sequence(
            full_att_mask, batch_first=True, padding_value=0
        )
        gen_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=att_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            **kwargs,
        )
        if return_tokens:
            return gen_ids[0] if single else gen_ids
        txt = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return txt[0] if single else txt

    def save_pretrained(self, save_directory, **kwargs):
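        """Save config, full state dict, tokenizer, and the matching Whisper
        feature extractor into `save_directory`."""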
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        state_dict = self.state_dict()
        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
        self.tokenizer.save_pretrained(save_directory)
        # Ship the feature extractor alongside so inference can load it from the repo.
        extractor = WhisperFeatureExtractor.from_pretrained(
            self.config.whisper_encoder_name
        )
        extractor.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
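        """Rebuild the model: config and tokenizer come from the repo, the base
        LLM from `config.llm_name`, then the trained weights are fetched and
        loaded. Note this expects a Hub repo id rather than a local directory,
        since the weights come via `hf_hub_download`."""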
        config = BorealisConfig.from_pretrained(pretrained_model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        language_model = AutoModelForCausalLM.from_pretrained(config.llm_name)
        model = cls(config, language_model=language_model, tokenizer=tokenizer)

        state_dict_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path, filename="pytorch_model.bin"
        )

        state_dict = torch.load(state_dict_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model


BorealisForConditionalGeneration.register_for_auto_class("AutoModelForCausalLM")
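
# Minimal usage sketch (hypothetical repo id and audio file; the feature
# extractor is saved next to the model by save_pretrained above; resample
# the audio to 16 kHz first if needed):
#
#   import torch
#   import torchaudio
#   from transformers import WhisperFeatureExtractor
#
#   model = BorealisForConditionalGeneration.from_pretrained("org/borealis")
#   extractor = WhisperFeatureExtractor.from_pretrained("org/borealis")
#   wav, sr = torchaudio.load("sample.wav")
#   feats = extractor(wav.squeeze(0).numpy(), sampling_rate=sr, return_tensors="pt")
#   mel = feats.input_features[0]                      # (n_mels, frames)
#   text = model.generate(mel, torch.ones(mel.size(1)))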