# Uploaded to the Hugging Face Hub by zirobtc via huggingface_hub
# ("Upload folder using huggingface_hub", commit fbb20ff, verified).
import torch
def load_and_freeze_llm(llm_version):
    """Load a T5 encoder plus its tokenizer and freeze the encoder weights.

    Args:
        llm_version: Model identifier or local path forwarded to both
            ``T5Tokenizer.from_pretrained`` and
            ``T5EncoderModel.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is the ``T5EncoderModel`` put
        in eval mode with ``requires_grad`` disabled on all of its parameters.
        ``tokenizer`` is the matching ``T5Tokenizer``.
    """
    # Function-local import: importing this module does not pull in transformers.
    from transformers import T5EncoderModel, T5Tokenizer

    tok = T5Tokenizer.from_pretrained(llm_version)
    encoder = T5EncoderModel.from_pretrained(llm_version)

    # Freeze the encoder: switch to eval mode (e.g. no dropout) and stop
    # autograd from tracking any of its parameters.
    encoder.eval()
    encoder.requires_grad_(False)
    return encoder, tok
def encode_text_batch(
    raw_text, text_encoder, tokenizer, device="cuda", max_text_len=50
):
    """Encode a batch of text prompts into padding-masked encoder hidden states.

    Args:
        raw_text: List of prompt strings, one per batch element. A single
            string is also accepted and treated as a batch of one.
        text_encoder: Module called as
            ``text_encoder(input_ids=..., attention_mask=...)`` whose output
            exposes ``last_hidden_state``. An example is the
            ``T5EncoderModel`` returned by ``load_and_freeze_llm``.
        tokenizer: Hugging Face tokenizer used to tokenize ``raw_text``.
        device: Device the token ids and attention mask are moved to.
            Presumably this must match the device of ``text_encoder``.
        max_text_len: Every prompt is padded or truncated to exactly this
            many tokens. The default of 50 is the previously hard-coded value.

    Returns:
        Encoder hidden states sliced to ``max_text_len`` positions, with every
        position whose attention-mask entry is 0 (padding) zeroed out.
        Presumably the shape is ``(batch, max_text_len, hidden_size)``.

    Raises:
        ValueError: If ``max_text_len`` is not positive.
    """
    if isinstance(raw_text, str):
        # Treat a single prompt as a batch of one.
        raw_text = [raw_text]
    if max_text_len <= 0:
        raise ValueError(f"max_text_len must be positive, got {max_text_len}")

    with torch.no_grad():
        # Pad/truncate every prompt to exactly max_text_len tokens so all
        # outputs share the same sequence length.
        # Calling the tokenizer directly replaces the deprecated
        # batch_encode_plus.
        encoded = tokenizer(
            raw_text,
            return_tensors="pt",
            padding="max_length",
            max_length=max_text_len,
            truncation=True,
        )
        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device)
        output = text_encoder(input_ids=input_ids, attention_mask=attn_mask)

        # no_grad() already prevents graph construction, so no detach() is
        # needed. The slices are a defensive guard in case the tokenizer
        # returned more positions than requested.
        encoded_text = output.last_hidden_state[:, :max_text_len]
        attn_mask = attn_mask[:, :max_text_len]

        # Zero the hidden states at padding positions (mask == 0). Casting
        # the integer mask keeps the result in the encoder's output dtype.
        encoded_text = encoded_text * attn_mask.unsqueeze(-1).to(encoded_text.dtype)
    return encoded_text