# Medal-S-V1.0 / model / tokenizer.py
# Uploaded by spc819 ("Upload 69 files", commit 7f3dfd7, verified)
import torch
from typing import Union, List
from transformers import AutoTokenizer
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class MyTokenizer():
def __init__(self, tokenizer, max_length=256):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
self.max_length = max_length
def tokenize(self, texts:[str, List[str]]) -> torch.LongTensor:
"""
tokenize a lits of strings or a single string, pad/trunctate to max length input of the text tower
Args:
texts (str, List[str]]): a string
Returns:
torch.LongTensor: the tokenized tensor and the attention mask(mask out paddings)
"""
if isinstance(texts, str):
texts = [texts]
sot_token = '[CLS]'
eot_token = '[SEP]'
all_token_ids = []
max_len_in_this_batch = 0
for text in texts: # a string
tokens = [sot_token] + self.tokenizer.tokenize(text) + [eot_token] # list of str
if len(tokens) > max_len_in_this_batch:
max_len_in_this_batch = len(tokens)
all_token_ids.append(self.tokenizer.convert_tokens_to_ids(tokens))
if max_len_in_this_batch > self.max_length:
max_len_in_this_batch = self.max_length
result = torch.zeros(len(all_token_ids), max_len_in_this_batch, dtype=torch.long)
for i, token_ids in enumerate(all_token_ids): # list of int
if len(token_ids) > max_len_in_this_batch:
token_ids = token_ids[:max_len_in_this_batch] # Truncate
token_ids[-1] = self.tokenizer.convert_tokens_to_ids('[SEP]')
result[i, :len(token_ids)] = torch.tensor(token_ids)
attn_mask = torch.where(result>0, 1, 0)
return {'input_ids':result, 'attention_mask':attn_mask}