# modeling_borealis.py
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    WhisperFeatureExtractor,
    WhisperModel,
)

from .configuration_borealis import BorealisConfig
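
# Borealis couples a frozen Whisper encoder with a causal LLM for speech
# recognition: encoder frames are temporally downsampled, projected into the
# LLM embedding space by a small MLP adapter, and spliced into the chat
# prompt between the <|start_of_audio|> and <|end_of_audio|> tokens.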


class AudioLanguageAdapter(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear, no biases) that projects the
    downsampled encoder states into the LLM embedding space."""

    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))


class BorealisForConditionalGeneration(PreTrainedModel, PyTorchModelHubMixin):
    config_class = BorealisConfig

    def __init__(self, config: BorealisConfig, language_model=None, tokenizer=None):
        super().__init__(config)
        assert tokenizer is not None, "A tokenizer must be passed to the model"
        # Frozen Whisper encoder, run in bfloat16; only the adapter (and the
        # LLM, unless the caller freezes it) receives gradients.
        self.encoder: WhisperModel = WhisperModel.from_pretrained(
            config.whisper_encoder_name
        ).encoder
        self.encoder.to(torch.bfloat16)
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.llm = language_model
        self.tokenizer = tokenizer
        self.llm.resize_token_embeddings(len(tokenizer))
        # Every `downsample_factor` consecutive encoder frames are concatenated
        # into one vector before being projected into the LLM embedding space,
        # hence the `d_model * downsample_factor` adapter input size.
        self.downsample_factor = config.downsample_factor
        self.adapter = AudioLanguageAdapter(
            hidden_size=self.encoder.config.d_model * self.downsample_factor,
            dim=self.llm.config.hidden_size,
        )
        self.adapter.to(torch.bfloat16)
        # Special tokens that delimit the audio span inside the chat prompt.
        self.bos_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.audio_start_id = tokenizer.convert_tokens_to_ids("<|start_of_audio|>")
        self.audio_end_id = tokenizer.convert_tokens_to_ids("<|end_of_audio|>")
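
    # Concatenate every k = downsample_factor consecutive encoder frames into
    # one vector, padding T up to a multiple of k: a (T, d) sequence becomes
    # (ceil(T / k), k * d), e.g. with k = 4 a (1500, d) Whisper output turns
    # into (375, 4 * d).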
def _downsample(self, seq: torch.Tensor) -> torch.Tensor:
k, (T, d) = self.downsample_factor, seq.shape
target = k * math.ceil(T / k)
if target != T:
seq = F.pad(seq, (0, 0, 0, target - T))
return seq.contiguous().view(target // k, d * k)

    def _tok_embed(self, tok_id: int, batch: int, device) -> torch.Tensor:
        # Embed a single token id for a whole batch (helper; currently unused
        # in this file).
        idx = torch.full((batch, 1), tok_id, dtype=torch.long, device=device)
        return self.llm.get_input_embeddings()(idx)
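
    # Training forward pass: `labels` holds the fully tokenized chat (with the
    # <|start_of_audio|><|end_of_audio|> pair); audio embeddings are spliced
    # between those tokens, and the LM loss covers only the assistant answer.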
    def forward(
        self,
        mel: torch.Tensor,
        audio_att_mask: torch.Tensor,
        labels: torch.Tensor,
        text_att_mask: torch.Tensor,
    ):
        B, device = mel.size(0), mel.device
        # Match the bfloat16 encoder weights, as generate() does. Note that
        # audio_att_mask is accepted for API symmetry but the encoder is run
        # without an attention mask, again as in generate().
        mel = mel.to(torch.bfloat16)
        enc_out = self.encoder(
            input_features=mel, attention_mask=None, return_dict=True
        ).last_hidden_state
        # Downsample each sequence, then right-pad all of them to a common
        # length so they can be stacked into one batch tensor.
        audio_embs, audio_mask, max_T = [], [], 0
        for seq in enc_out:
            ds = self._downsample(seq)
            audio_embs.append(ds)
            max_T = max(max_T, ds.size(0))
        for i, ds in enumerate(audio_embs):
            pad = max_T - ds.size(0)
            audio_mask.append(
                torch.cat(
                    [
                        torch.ones(ds.size(0), dtype=torch.long, device=device),
                        torch.zeros(pad, dtype=torch.long, device=device),
                    ]
                )
            )
            if pad:
                # Write the padded tensor back into the list; padding only the
                # loop-local copy would leave ragged lengths and make the
                # torch.stack below fail.
                audio_embs[i] = F.pad(ds, (0, 0, 0, pad))
        audio_embeddings = torch.stack(audio_embs, 0)
        audio_mask = torch.stack(audio_mask, 0)
        audio_embeddings = self.adapter(audio_embeddings)
        text_embeddings = self.llm.get_input_embeddings()(labels)
        sa_positions = (labels == self.audio_start_id).nonzero(as_tuple=True)
        ea_positions = (labels == self.audio_end_id).nonzero(as_tuple=True)
        # Splice the audio embeddings into each row: keep the prefix up to and
        # including <|start_of_audio|>, insert max_T audio frames, keep the
        # suffix starting at <|end_of_audio|>.
        inputs_embeds = []
        att_mask = []
        offsets = []  # per-row right-shift of suffix tokens after splicing
        for b in range(B):
            sa_idx = sa_positions[1][sa_positions[0] == b].item()
            ea_idx = ea_positions[1][ea_positions[0] == b].item()
            offsets.append(max_T - (ea_idx - sa_idx - 1))
            prefix_emb = text_embeddings[b, : sa_idx + 1]
            postfix_emb = text_embeddings[b, ea_idx:]
            emb = torch.cat([prefix_emb, audio_embeddings[b], postfix_emb], dim=0)
            prefix_mask = text_att_mask[b, : sa_idx + 1]
            postfix_mask = text_att_mask[b, ea_idx:]
            full_mask = torch.cat([prefix_mask, audio_mask[b], postfix_mask], dim=0)
            inputs_embeds.append(emb)
            att_mask.append(full_mask)
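        # Right-pad to a rectangular batch; padded positions carry mask 0 and
        # are ignored by the LLM.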
inputs_embeds = torch.nn.utils.rnn.pad_sequence(
inputs_embeds, batch_first=True, padding_value=0.0
)
att_mask = torch.nn.utils.rnn.pad_sequence(
att_mask, batch_first=True, padding_value=0
)
        # Locate "<|im_start|>assistant\n" in each label row; loss is computed
        # only on the tokens that follow it (the assistant's answer).
        assistant_prompt = self.tokenizer(
            "<|im_start|>assistant\n", add_special_tokens=False
        ).input_ids
        assistant_prompt = torch.tensor(assistant_prompt, device=device)
        assistant_starts = []
        for b in range(B):
            seq = labels[b]
            for i in range(len(seq) - len(assistant_prompt) + 1):
                if torch.equal(seq[i : i + len(assistant_prompt)], assistant_prompt):
                    assistant_start = i + len(assistant_prompt)
                    break
            else:
                raise ValueError("Assistant prompt not found")
            # Shift into the spliced sequence's coordinates.
            assistant_starts.append(assistant_start + offsets[b])
        # Build loss labels aligned with the spliced sequences: everything
        # before the assistant answer (prompt, audio, padding) stays at -100.
        max_len = inputs_embeds.size(1)
        loss_labels = labels.new_full((B, max_len), -100)
        for b in range(B):
            orig_assist_start = assistant_starts[b] - offsets[b]
            content_len = len(labels[b]) - orig_assist_start
            loss_labels[b, assistant_starts[b] : assistant_starts[b] + content_len] = (
                labels[b, orig_assist_start:]
            )
        if self.tokenizer.pad_token_id is not None:
            loss_labels[loss_labels == self.tokenizer.pad_token_id] = -100
out = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=att_mask,
labels=loss_labels,
return_dict=True,
)
return out.loss, out.logits
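
    # Inference: build a fixed ASR chat prompt, splice in the audio
    # embeddings, and decode with the LLM.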
@torch.no_grad()
def generate(
self,
mel: torch.Tensor,
att_mask: torch.Tensor,
max_new_tokens: int = 512,
**kwargs,
):
return_tokens = kwargs.pop("return_tokens", False)
single = mel.dim() == 2
if single:
mel, att_mask = mel.unsqueeze(0), att_mask.unsqueeze(0)
mel = mel.to(torch.bfloat16)
B, device = mel.size(0), mel.device
enc_out = self.encoder(
input_features=mel, attention_mask=None, return_dict=True
).last_hidden_state
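        # Same downsample-and-pad scheme as in forward().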
audio_embs, audio_mask, max_T = [], [], 0
for seq in enc_out:
ds = self._downsample(seq)
audio_embs.append(ds)
max_T = max(max_T, ds.size(0))
for i, ds in enumerate(audio_embs):
pad = max_T - ds.size(0)
audio_mask.append(
torch.cat(
[
torch.ones(ds.size(0), dtype=torch.long, device=device),
torch.zeros(pad, dtype=torch.long, device=device),
]
)
)
if pad:
audio_embs[i] = F.pad(ds, (0, 0, 0, pad))
audio_embeddings = torch.stack(audio_embs, 0)
audio_mask = torch.stack(audio_mask, 0)
audio_embeddings = self.adapter(audio_embeddings)
        # Fixed ASR prompt, kept in Russian to match training. In English:
        # system: "You are a helpful automatic speech recognition assistant.
        # Transcribe the audio into text accurately." / user: "Transcribe this
        # audio: ...".
        messages = [
            {
                "role": "system",
                "content": "Вы полезный помощник по автоматическому распознаванию речи. Точно транскрибируйте аудио в текст.",
            },
            {
                "role": "user",
                "content": "Транскрибируйте это аудио: <|start_of_audio|><|end_of_audio|>",
            },
        ]
chat_text = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
model_inputs = self.tokenizer(chat_text, return_tensors="pt").to(device)
input_ids = model_inputs.input_ids.repeat(B, 1)
text_att_mask = model_inputs.attention_mask.repeat(B, 1)
text_embeddings = self.llm.get_input_embeddings()(input_ids)
sa_idx = (input_ids[0] == self.audio_start_id).nonzero(as_tuple=True)[0].item()
ea_idx = (input_ids[0] == self.audio_end_id).nonzero(as_tuple=True)[0].item()
inputs_embeds = []
full_att_mask = []
for b in range(B):
prefix_emb = text_embeddings[b, : sa_idx + 1]
postfix_emb = text_embeddings[b, ea_idx:]
emb = torch.cat([prefix_emb, audio_embeddings[b], postfix_emb], dim=0)
prefix_mask = text_att_mask[b, : sa_idx + 1]
postfix_mask = text_att_mask[b, ea_idx:]
mask = torch.cat([prefix_mask, audio_mask[b], postfix_mask], dim=0)
inputs_embeds.append(emb)
full_att_mask.append(mask)
inputs_embeds = torch.nn.utils.rnn.pad_sequence(
inputs_embeds, batch_first=True, padding_value=0.0
)
att_mask = torch.nn.utils.rnn.pad_sequence(
full_att_mask, batch_first=True, padding_value=0
)
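        # When called with inputs_embeds (and no input_ids), llm.generate
        # returns only the newly generated token ids, so gen_ids can be
        # decoded directly.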
gen_ids = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=att_mask,
max_new_tokens=max_new_tokens,
eos_token_id=self.tokenizer.eos_token_id,
**kwargs,
)
        if return_tokens:
            return gen_ids[0] if single else gen_ids
        txt = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return txt[0] if single else txt

    def save_pretrained(self, save_directory, **kwargs):
        # Persist config, the full state dict (encoder + adapter + LLM), the
        # tokenizer, and the matching Whisper feature extractor.
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        state_dict = self.state_dict()
        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
        self.tokenizer.save_pretrained(save_directory)
        extractor = WhisperFeatureExtractor.from_pretrained(
            self.config.whisper_encoder_name
        )
        extractor.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = BorealisConfig.from_pretrained(pretrained_model_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        # The base LLM is re-instantiated from config.llm_name; its weights
        # are then overwritten by the checkpoint's full state dict below.
        language_model = AutoModelForCausalLM.from_pretrained(config.llm_name)
        model = cls(config, language_model=language_model, tokenizer=tokenizer)
        # Accept both a local directory (as written by save_pretrained) and a
        # hub repo id; hf_hub_download handles only the latter.
        local_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        if os.path.isfile(local_path):
            state_dict_path = local_path
        else:
            state_dict_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename="pytorch_model.bin"
            )
        state_dict = torch.load(state_dict_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model
BorealisForConditionalGeneration.register_for_auto_class("AutoModelForCausalLM")
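
# Usage sketch (not part of the original file; the checkpoint id and waveform
# are placeholders). Given a real checkpoint and a 16 kHz mono `audio` array:
#
#     model = BorealisForConditionalGeneration.from_pretrained("user/borealis-checkpoint")
#     extractor = WhisperFeatureExtractor.from_pretrained(model.config.whisper_encoder_name)
#     feats = extractor(audio, sampling_rate=16000, return_tensors="pt",
#                       return_attention_mask=True)
#     # A 2-D mel tensor takes the single-example path in generate().
#     text = model.generate(mel=feats.input_features[0], att_mask=feats.attention_mask[0])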