# Source: music_text/text_analysis.py (Hugging Face repository file page)
# Author: mechtnet
# Last commit: 765c670 "Update text_analysis.py" (verified)
# File size: 920 bytes
import torch
from transformers import LongformerModel, LongformerTokenizerFast
# Hugging Face Hub checkpoint shared by the model and its matching tokenizer.
_CHECKPOINT = 'kazzand/ru-longformer-base-4096'
model = LongformerModel.from_pretrained(_CHECKPOINT)
tokenizer = LongformerTokenizerFast.from_pretrained(_CHECKPOINT)
def get_cls_embedding(text, model, tokenizer, device='cpu'):
    """Return the final-layer hidden state at position 0 (the CLS token) for *text*.

    CLS tokens are given global attention so that, in a Longformer, their
    representation can attend to the whole sequence instead of only a local
    window.

    Args:
        text: A string, or a list of strings. A list of strings of different
            lengths is padded to a common length; the tokenizer's
            ``attention_mask`` keeps padding out of the computation.
        model: A Longformer-style model that accepts ``global_attention_mask``
            and returns an object exposing ``last_hidden_state``.
        tokenizer: The tokenizer matching *model*.
        device: Device that both the model and the inputs are moved to.
            ``model.to(device)`` moves the caller's model in place.

    Returns:
        ``output.last_hidden_state[:, 0, :]``: one vector per input text,
        i.e. shape ``(batch, hidden)`` under the usual Hugging Face
        ``(batch, seq, hidden)`` layout of ``last_hidden_state``.
    """
    model.to(device)
    # padding=True lets a list of unequal-length texts be stacked into one
    # tensor (it is a no-op for a single string). truncation=True cuts inputs
    # longer than tokenizer.model_max_length, which would otherwise overflow
    # the position embeddings and raise.
    # NOTE(review): assumes model_max_length is configured (4096) for this
    # checkpoint; if it is not, Transformers warns and does not truncate.
    batch = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    # Global attention (1) only on CLS tokens; every other token, including
    # padding, keeps local attention (0). The vectorized comparison replaces a
    # per-token Python loop. .long() yields int64, the dtype the original
    # list-of-ints construction produced.
    batch["global_attention_mask"] = (
        batch["input_ids"] == tokenizer.cls_token_id
    ).long()
    # Inference only: skip autograd bookkeeping. BatchEncoding.to also moves
    # the global_attention_mask that was just added.
    with torch.no_grad():
        output = model(**batch.to(device))
    return output.last_hidden_state[:, 0, :]