import os

import numpy as np
import textgrid as tg
import torch
from loguru import logger
from transformers import AutoTokenizer, BertModel
def process_word_data(data_dir, word_file, args, data, f_name, selected_file, lang_model):
    """Process word/text data with support for different encoders."""
    logger.info(f"# ---- Building cache for Word {f_name} ---- #")
    if not os.path.exists(word_file):
        logger.warning(f"# ---- file not found for Word {f_name}, skip all files with the same id ---- #")
        selected_file.drop(selected_file[selected_file['id'] == f_name].index, inplace=True)
        return None
    word_save_path = f"{data_dir}{args.t_pre_encoder}/{f_name}.npy"
    # Reuse a previously built cache for this file id if one exists.
    if os.path.exists(word_save_path):
        data['word'] = np.load(word_save_path)
        logger.warning(f"# ---- file found cache for Word {f_name} ---- #")
        return data
    tgrid = tg.TextGrid.fromFile(word_file)
    if args.t_pre_encoder == "bert":
        word_data = process_bert_encoding(tgrid, f_name, args)
    else:
        word_data = process_basic_encoding(tgrid, data, args, lang_model)
    data['word'] = np.array(word_data)
    os.makedirs(os.path.dirname(word_save_path), exist_ok=True)
    np.save(word_save_path, data['word'])
    return data
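# Usage sketch (hypothetical wiring: the `args` fields, DataFrame layout, and
# paths below are inferred from how they are accessed above, not a documented API):
#
#   import pandas as pd
#   from types import SimpleNamespace
#
#   args = SimpleNamespace(t_pre_encoder="bert", data_path_1="./", pose_fps=30)
#   selected = pd.DataFrame({"id": ["sample_0"]})
#   out = process_word_data("./cache/", "./sample_0.TextGrid", args,
#                           {}, "sample_0", selected, lang_model=None)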
def process_bert_encoding(tgrid, f_name, args):
    """Process text data using BERT encoding."""
    # Load the locally cached tokenizer and model (no network access).
    tokenizer = AutoTokenizer.from_pretrained(
        args.data_path_1 + "hub/bert-base-uncased",
        local_files_only=True
    )
    model = BertModel.from_pretrained(
        args.data_path_1 + "hub/bert-base-uncased",
        local_files_only=True
    ).eval()
    list_word = []
    all_hidden = []
    # word_token_mapping collects word -> sub-word token indices across
    # batches; only the hidden states are returned by this function.
    word_token_mapping = []
    max_len = 400  # words per batch, presumably sized to stay under BERT's 512-token limit
    global_len = 0
    for i, word in enumerate(tgrid[0]):  # first tier of the TextGrid
        if i % max_len == 0 and i > 0:
            # Encode the current batch before starting a new one.
            encoded_data = process_bert_batch(
                list_word, tokenizer, model, word_token_mapping, global_len
            )
            all_hidden.append(encoded_data['hidden_states'])
            global_len = encoded_data['global_len']
            list_word = []
        # Empty interval marks (silences) are encoded as periods.
        list_word.append("." if word.mark == "" else word.mark)
    # Encode the final, possibly shorter batch.
    if list_word:
        encoded_data = process_bert_batch(
            list_word, tokenizer, model, word_token_mapping, global_len
        )
        all_hidden.append(encoded_data['hidden_states'])
    return np.concatenate(all_hidden, axis=0) if all_hidden else np.array([])
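# Batching sketch: with max_len = 400, a tier of 1000 words is encoded in
# three passes (words 0-399, 400-799, 800-999). The per-batch hidden states
# are concatenated along the token axis, giving one 768-d row per sub-word
# token, special tokens excluded.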
def process_bert_batch(word_list, tokenizer, model, word_token_mapping, global_len):
    """Process a batch of words through BERT."""
    str_word = ' '.join(word_list)
    # Tokenize once, keeping character offsets for word/token alignment
    # (requires a fast tokenizer, which AutoTokenizer returns by default).
    inputs = tokenizer(str_word, return_tensors="pt", return_offsets_mapping=True)
    token_offsets = inputs.pop("offset_mapping")[0].tolist()
    word_offsets = get_word_offsets(word_list)
    # Map each word to the indices of the sub-word tokens it spans;
    # token_offsets[1:-1] skips the [CLS] and [SEP] special tokens.
    for start, end in word_offsets:
        sub_mapping = []
        for i, (start_t, end_t) in enumerate(token_offsets[1:-1]):
            if start <= start_t and end_t <= end:
                sub_mapping.append(i + global_len)
        word_token_mapping.append(sub_mapping)
    # Get BERT embeddings, dropping the [CLS]/[SEP] rows.
    with torch.no_grad():
        outputs = model(**inputs)
    hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
    return {
        'hidden_states': hidden_states,
        # Advance the global token counter by this batch's content tokens
        # (robust even when a word matched no sub-word tokens).
        'global_len': global_len + len(token_offsets) - 2,
    }
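# Alignment sketch (the token split is illustrative; exact pieces depend on
# bert-base-uncased's WordPiece vocabulary): for word_list = ["tokenization",
# "works"], str_word = "tokenization works" tokenizes to roughly ["token",
# "##ization", "works"], so word_token_mapping gains [global_len + 0,
# global_len + 1] for "tokenization" and [global_len + 2] for "works".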
def get_word_offsets(word_list):
    """Calculate character offsets for each word in the list."""
    offsets = []
    current_pos = 0
    for word in word_list:
        start = current_pos
        end = start + len(word)
        offsets.append((start, end))
        current_pos = end + 1  # +1 for the space separator
    return offsets
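# Example: get_word_offsets(["hi", "there"]) returns [(0, 2), (3, 8)],
# matching character positions in the joined string "hi there".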
def process_basic_encoding(tgrid, data, args, lang_model):
    """Process basic word encoding."""
    word_data = []
    # Assign one word index per pose frame by time alignment.
    for i in range(data['pose'].shape[0]):
        current_time = i / args.pose_fps
        found_word = False
        for word in tgrid[0]:
            if word.minTime <= current_time <= word.maxTime:
                if not word.mark.strip():
                    # Empty or whitespace-only marks are silences.
                    word_data.append(lang_model.PAD_token)
                else:
                    word_data.append(lang_model.get_word_index(word.mark))
                found_word = True
                break
        if not found_word:
            word_data.append(lang_model.UNK_token)
    return word_data
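# Minimal self-check (a sketch: exercises only the dependency-free helper,
# since the full pipeline needs a TextGrid file, a locally cached BERT model,
# and the `args`/`lang_model` objects used above):
if __name__ == "__main__":
    words = ["hello", "world", "."]
    offsets = get_word_offsets(words)
    joined = ' '.join(words)
    # Each (start, end) pair slices the joined string back to its word.
    assert all(joined[s:e] == w for (s, e), w in zip(offsets, words))
    print(offsets)  # [(0, 5), (6, 11), (12, 13)]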