File size: 1,191 Bytes
fbb20ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch


def load_and_freeze_llm(llm_version):
    """Load a pretrained T5 encoder and its tokenizer with all weights frozen.

    Args:
        llm_version: Model identifier or local path handed to both
            ``T5Tokenizer.from_pretrained`` and ``T5EncoderModel.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple: the ``T5EncoderModel`` switched to eval
        mode with ``requires_grad`` disabled on every parameter, and the
        matching ``T5Tokenizer``.
    """
    # Imported inside the function, so transformers is only required when this
    # loader is actually called.
    from transformers import T5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained(llm_version)
    encoder = T5EncoderModel.from_pretrained(llm_version)

    # Freeze: eval mode plus no gradient tracking on any parameter.
    # Module.requires_grad_ visits every parameter, like an explicit loop would.
    encoder.eval()
    encoder.requires_grad_(False)

    return encoder, tokenizer


def encode_text_batch(raw_text, text_encoder, tokenizer, device="cuda", max_text_len=50):
    """Encode a batch of text prompts into fixed-length, padding-masked features.

    Args:
        raw_text: List of prompt strings, one per batch element. A single
            string is also accepted and treated as a batch of one.
        text_encoder: Encoder called as
            ``text_encoder(input_ids=..., attention_mask=...)``; its output
            must expose ``last_hidden_state``.
        tokenizer: Tokenizer providing ``batch_encode_plus``.
        device: Device the token ids and attention mask are moved to before
            encoding; it should be the device ``text_encoder`` lives on.
        max_text_len: Number of token positions each prompt is padded or
            truncated to. Defaults to 50.

    Returns:
        Tensor of shape ``(batch, max_text_len, hidden)`` holding the
        encoder's last hidden state. Positions where the attention mask is 0
        (padding) are zeroed. Computed under ``torch.no_grad()``.
    """
    # batch_encode_plus expects a list of strings; a bare string could be
    # treated as a sequence of single-character prompts, so wrap it.
    if isinstance(raw_text, str):
        raw_text = [raw_text]

    with torch.no_grad():
        encoded = tokenizer.batch_encode_plus(
            raw_text,
            return_tensors="pt",
            padding="max_length",
            max_length=max_text_len,
            truncation=True,
        )
        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device)

        output = text_encoder(input_ids=input_ids, attention_mask=attn_mask)
        encoded_text = output.last_hidden_state.detach()

        # Padding/truncation to max_length should already give max_text_len
        # positions; the slice is a defensive guard if the tokenizer returns more.
        encoded_text = encoded_text[:, :max_text_len]
        attn_mask = attn_mask[:, :max_text_len]

        # Zero features at padded positions. Cast the (typically integer) mask
        # to the feature dtype and multiply out-of-place instead of mutating
        # the encoder's output tensor.
        encoded_text = encoded_text * attn_mask.unsqueeze(-1).to(encoded_text.dtype)

    return encoded_text