import torch
import transformers
from typing import List
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange

# silence transformers' verbose warnings (e.g. about unused decoder weights
# when loading a full T5 checkpoint into an encoder-only model)
transformers.logging.set_verbosity_error()

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    # return val if provided, otherwise fall back to d
    # (d may be a callable, so expensive defaults are computed lazily)
    if exists(val):
        return val
    return d() if callable(d) else d

# config

MAX_LENGTH = 256

DEFAULT_T5_NAME = 'google/t5-v1_1-base'

# cache of models / tokenizers / configs, keyed by model name,
# so each checkpoint is loaded at most once per process
T5_CONFIGS = {}

# singleton getters

def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
    return tokenizer

def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
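
# usage sketch (hedged): the first call downloads/loads the checkpoint,
# later calls are served from the T5_CONFIGS cache
#
#   model, tokenizer = get_model_and_tokenizer(DEFAULT_T5_NAME)
#   model_again, _ = get_model_and_tokenizer(DEFAULT_T5_NAME)
#   assert model_again is model  # same object, not reloaded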

def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoid loading the model weights just to read off the dimension
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model
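
# usage sketch (hedged): reads d_model from the hub config alone, so no
# model weights are downloaded; for 'google/t5-v1_1-base' this is expected
# to be 768, but the authoritative value is whatever the config declares
#
#   text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME)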

# encoding text

def t5_tokenize(
    texts: List[str],
    name = DEFAULT_T5_NAME
):
    t5, tokenizer = get_model_and_tokenizer(name)

    # move the encoder to gpu if one is available
    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    # pad to the longest sequence in the batch, truncating at MAX_LENGTH
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)
    return input_ids, attn_mask
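
# usage sketch (hedged, example strings are arbitrary):
#
#   ids, mask = t5_tokenize(['a cat', 'a dog sitting on a mat'])
#   # ids, mask: LongTensors of shape (2, seq_len), seq_len <= MAX_LENGTH,
#   # padded to the longest sequence; mask is 1 for tokens, 0 for padding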

def t5_encode_tokenized_text(
    token_ids,
    attn_mask = None,
    pad_id = None,
    name = DEFAULT_T5_NAME
):
    # need either an explicit attention mask or a pad token id to derive one
    assert exists(attn_mask) or exists(pad_id)
    t5, _ = get_model_and_tokenizer(name)

    attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = token_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()

    # zero out the embeddings at padded positions
    encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.)
    return encoded_text
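
# usage sketch (hedged): with precomputed token ids and no stored mask, the
# mask can be rebuilt from the pad token; pad_id = 0 assumes the standard
# T5 tokenizer convention
#
#   encoded = t5_encode_tokenized_text(token_ids, pad_id = 0)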

def t5_encode_text(
    texts: List[str],
    name = DEFAULT_T5_NAME,
    return_attn_mask = False
):
    token_ids, attn_mask = t5_tokenize(texts, name = name)
    encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)

    if return_attn_mask:
        attn_mask = attn_mask.bool()
        return encoded_text, attn_mask

    return encoded_text
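
# a minimal end-to-end sketch (hypothetical demo, not part of the module's
# public surface); the first run downloads the default checkpoint

if __name__ == '__main__':
    embeds, mask = t5_encode_text(
        ['a photo of a golden retriever', 'a watercolor of mountains'],
        return_attn_mask = True
    )
    # embeds: (2, seq_len, d_model) with padded positions zeroed out
    # mask:   (2, seq_len) bool
    print(embeds.shape, mask.shape, get_encoded_dim(DEFAULT_T5_NAME))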