import os

import numpy as np
import textgrid as tg
import torch
from loguru import logger
from transformers import AutoTokenizer, BertModel


def process_word_data(data_dir, word_file, args, data, f_name, selected_file, lang_model):
"""Process word/text data with support for different encoders."""
logger.info(f"# ---- Building cache for Word {f_name} ---- #")
if not os.path.exists(word_file):
logger.warning(f"# ---- file not found for Word {f_name}, skip all files with the same id ---- #")
selected_file.drop(selected_file[selected_file['id'] == f_name].index, inplace=True)
return None
word_save_path = f"{data_dir}{args.t_pre_encoder}/{f_name}.npy"
if os.path.exists(word_save_path):
data['word'] = np.load(word_save_path)
logger.warning(f"# ---- file found cache for Word {f_name} ---- #")
return data
tgrid = tg.TextGrid.fromFile(word_file)
    if args.t_pre_encoder == "bert":
        word_data = process_bert_encoding(tgrid, f_name, args)
    else:
        word_data = process_basic_encoding(tgrid, data, args, lang_model)
data['word'] = np.array(word_data)
os.makedirs(os.path.dirname(word_save_path), exist_ok=True)
np.save(word_save_path, data['word'])
return data


def process_bert_encoding(tgrid, f_name, args):
"""Process text data using BERT encoding."""
tokenizer = AutoTokenizer.from_pretrained(
args.data_path_1 + "hub/bert-base-uncased",
local_files_only=True
)
model = BertModel.from_pretrained(
args.data_path_1 + "hub/bert-base-uncased",
local_files_only=True
).eval()
list_word = []
all_hidden = []
word_token_mapping = []
    max_len = 400  # words per chunk fed to BERT, to keep each input within its length limit
global_len = 0
for i, word in enumerate(tgrid[0]):
if i % max_len == 0 and i > 0:
# Process current batch
encoded_data = process_bert_batch(
list_word, tokenizer, model, word_token_mapping, global_len
)
all_hidden.append(encoded_data['hidden_states'])
global_len = encoded_data['global_len']
list_word = []
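        # Empty marks are silences; substitute "." so BERT still emits a token for them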
list_word.append("." if word.mark == "" else word.mark)
# Process remaining words
if list_word:
encoded_data = process_bert_batch(
list_word, tokenizer, model, word_token_mapping, global_len
)
all_hidden.append(encoded_data['hidden_states'])
return np.concatenate(all_hidden, axis=0) if all_hidden else np.array([])


def process_bert_batch(word_list, tokenizer, model, word_token_mapping, global_len):
"""Process a batch of words through BERT."""
str_word = ' '.join(word_list)
# Get token mappings
    token_offsets = tokenizer(str_word, return_offsets_mapping=True)['offset_mapping']
word_offsets = get_word_offsets(word_list)
# Map words to tokens
for start, end in word_offsets:
sub_mapping = []
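        # token_offsets[0] and token_offsets[-1] belong to [CLS]/[SEP]; skip them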
for i, (start_t, end_t) in enumerate(token_offsets[1:-1]):
if int(start) <= int(start_t) and int(end_t) <= int(end):
sub_mapping.append(i + global_len)
word_token_mapping.append(sub_mapping)
# Get BERT embeddings
with torch.no_grad():
inputs = tokenizer(str_word, return_tensors="pt")
outputs = model(**inputs)
        # Drop the [CLS]/[SEP] rows so the embeddings align with the offsets mapped above
        hidden_states = outputs.last_hidden_state.squeeze(0).cpu().numpy()[1:-1, :]
    return {
        'hidden_states': hidden_states,
        # Next free token index: advance by the number of word tokens in this batch
        'global_len': global_len + hidden_states.shape[0],
    }


def get_word_offsets(word_list):
"""Calculate character offsets for each word in the list."""
offsets = []
current_pos = 0
for word in word_list:
start = current_pos
end = start + len(word)
offsets.append((start, end))
current_pos = end + 1 # +1 for the space
return offsets
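
# For example (illustrative): get_word_offsets(["hello", "world"]) returns
# [(0, 5), (6, 11)], the character spans of each word in "hello world".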


def process_basic_encoding(tgrid, data, args, lang_model):
    """Map each pose frame to a vocabulary index using the TextGrid word timings."""
word_data = []
for i in range(data['pose'].shape[0]):
        current_time = i / args.pose_fps  # timestamp of this pose frame in seconds
found_word = False
for word in tgrid[0]:
            if word.minTime <= current_time <= word.maxTime:
                # Silent intervals carry empty or whitespace-only marks
                if not word.mark.strip():
                    word_data.append(lang_model.PAD_token)
                else:
                    word_data.append(lang_model.get_word_index(word.mark))
found_word = True
break
if not found_word:
word_data.append(lang_model.UNK_token)
    return word_data
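

# A minimal usage sketch with hypothetical paths and settings. The `args` fields
# (`t_pre_encoder`, `pose_fps`, `data_path_1`) mirror how they are read above;
# everything else here is illustrative, not part of the original pipeline:
#
#     import pandas as pd
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(t_pre_encoder="bert", pose_fps=15, data_path_1="./datasets/")
#     selected_file = pd.DataFrame({"id": ["sample_0"]})
#     data = {"pose": np.zeros((150, 141))}
#     data = process_word_data(
#         data_dir="./cache/",
#         word_file="./textgrids/sample_0.TextGrid",
#         args=args,
#         data=data,
#         f_name="sample_0",
#         selected_file=selected_file,
#         lang_model=None,  # only used by the non-BERT branch
#     )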