| import torch | |
def load_and_freeze_llm(llm_version):
    """Load a pretrained T5 text encoder and its tokenizer, frozen for inference.

    Args:
        llm_version: Hugging Face model identifier or local path passed to
            ``from_pretrained`` (e.g. "t5-base").

    Returns:
        (model, tokenizer) tuple; the model is in eval mode with every
        parameter's ``requires_grad`` set to False.
    """
    # Imported lazily so transformers is only required when an LLM is loaded.
    from transformers import T5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained(llm_version)
    model = T5EncoderModel.from_pretrained(llm_version)

    # Freeze the encoder: eval mode plus disabling gradients on all parameters.
    model.eval()
    model.requires_grad_(False)
    return model, tokenizer
def encode_text_batch(raw_text, text_encoder, tokenizer, device="cuda", max_text_len=50):
    """Tokenize and encode a batch of text prompts with a frozen text encoder.

    Args:
        raw_text: list (batch_size length) of strings with input text prompts.
        text_encoder: model invoked as
            ``text_encoder(input_ids=..., attention_mask=...)``; its output
            must expose ``.last_hidden_state``.
        tokenizer: tokenizer exposing ``batch_encode_plus`` (HF-style).
        device: device the token tensors are moved to before encoding.
        max_text_len: every prompt is padded/truncated to this many tokens
            (was previously hard-coded to 50; default preserves old behavior).

    Returns:
        Tensor of shape (batch_size, max_text_len, hidden_dim). Embeddings at
        padding positions are zeroed out via the attention mask.
    """
    with torch.no_grad():
        encoded = tokenizer.batch_encode_plus(
            raw_text,
            return_tensors="pt",
            padding="max_length",
            max_length=max_text_len,
            truncation=True,
        )
        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device)
        output = text_encoder(input_ids=input_ids, attention_mask=attn_mask)
        # .detach() is redundant inside no_grad but kept as a cheap safeguard
        # should this code ever move outside the context manager.
        encoded_text = output.last_hidden_state.detach()
        # Defensive slice: padding="max_length" + truncation=True already fix
        # the sequence width to max_text_len for standard HF tokenizers.
        encoded_text = encoded_text[:, :max_text_len]
        attn_mask = attn_mask[:, :max_text_len]
        # Zero out embeddings at padded positions so downstream consumers
        # never see encoder activations for padding tokens.
        encoded_text *= attn_mask.unsqueeze(-1)
    return encoded_text