import os

import numpy as np
import textgrid as tg
import torch
from transformers import AutoTokenizer, BertModel
from loguru import logger

def process_word_data(data_dir, word_file, args, data, f_name, selected_file, lang_model):
    """Process word/text data with support for different encoders."""
    logger.info(f"# ---- Building cache for Word {f_name} ---- #")

    if not os.path.exists(word_file):
        logger.warning(f"# ---- file not found for Word {f_name}, skip all files with the same id ---- #")
        selected_file.drop(selected_file[selected_file['id'] == f_name].index, inplace=True)
        return None

    word_save_path = f"{data_dir}{args.t_pre_encoder}/{f_name}.npy"
    if os.path.exists(word_save_path):
        data['word'] = np.load(word_save_path)
        logger.warning(f"# ---- cache found for Word {f_name}, loading ---- #")
        return data

    tgrid = tg.TextGrid.fromFile(word_file)

    if args.t_pre_encoder == "bert":
        word_data = process_bert_encoding(tgrid, f_name, args)
    else:
        word_data = process_basic_encoding(tgrid, data, args, lang_model)

    data['word'] = np.array(word_data)
    os.makedirs(os.path.dirname(word_save_path), exist_ok=True)
    np.save(word_save_path, data['word'])
    return data

def process_bert_encoding(tgrid, f_name, args):
    """Process text data using BERT encoding."""
    tokenizer = AutoTokenizer.from_pretrained(
        args.data_path_1 + "hub/bert-base-uncased", 
        local_files_only=True
    )
    model = BertModel.from_pretrained(
        args.data_path_1 + "hub/bert-base-uncased", 
        local_files_only=True
    ).eval()
    
    list_word = []
    all_hidden = []
    word_token_mapping = []
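    # process the transcript in chunks of at most max_len words per BERT forward pass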
    max_len = 400
    global_len = 0
    
    for i, word in enumerate(tgrid[0]):
        if i % max_len == 0 and i > 0:
            # Process current batch
            encoded_data = process_bert_batch(
                list_word, tokenizer, model, word_token_mapping, global_len
            )
            all_hidden.append(encoded_data['hidden_states'])
            global_len = encoded_data['global_len']
            list_word = []
            
        list_word.append("." if word.mark == "" else word.mark)
    
    # Process remaining words
    if list_word:
        encoded_data = process_bert_batch(
            list_word, tokenizer, model, word_token_mapping, global_len
        )
        all_hidden.append(encoded_data['hidden_states'])
    
    return np.concatenate(all_hidden, axis=0) if all_hidden else np.array([])

def process_bert_batch(word_list, tokenizer, model, word_token_mapping, global_len):
    """Process a batch of words through BERT."""
    str_word = ' '.join(word_list)
    
    # Get token mappings
    token_offsets = tokenizer.encode_plus(str_word, return_offsets_mapping=True)['offset_mapping']
    word_offsets = get_word_offsets(word_list)
    
    # Map words to tokens
    for start, end in word_offsets:
        sub_mapping = []
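        # token_offsets[1:-1] skips the [CLS]/[SEP] offsets added by the tokenizer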
        for i, (start_t, end_t) in enumerate(token_offsets[1:-1]):
            if int(start) <= int(start_t) and int(end_t) <= int(end):
                sub_mapping.append(i + global_len)
        word_token_mapping.append(sub_mapping)
    
    # Get BERT embeddings
    with torch.no_grad():
        inputs = tokenizer(str_word, return_tensors="pt")
        outputs = model(**inputs)
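        # drop the [CLS] and [SEP] rows; 768 is the hidden size of bert-base-uncased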
        hidden_states = outputs.last_hidden_state.reshape(-1, 768).cpu().numpy()[1:-1, :]
    
    # advance the running token offset for the next chunk; guard against a word
    # that matched no tokens, which would leave the last sub-mapping empty
    if word_token_mapping and word_token_mapping[-1]:
        global_len = word_token_mapping[-1][-1] + 1

    return {
        'hidden_states': hidden_states,
        'global_len': global_len
    }

def get_word_offsets(word_list):
    """Calculate character offsets for each word in the list."""
    offsets = []
    current_pos = 0
    
    for word in word_list:
        start = current_pos
        end = start + len(word)
        offsets.append((start, end))
        current_pos = end + 1  # +1 for the space
        
    return offsets

def process_basic_encoding(tgrid, data, args, lang_model):
    """Process basic word encoding."""
    word_data = []
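    # assign one word index per pose frame by finding the interval that contains the frame's timestamp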
    for i in range(data['pose'].shape[0]):
        current_time = i/args.pose_fps
        found_word = False
        
        for word in tgrid[0]:
            if word.minTime <= current_time <= word.maxTime:
                # silence / empty intervals map to the padding token
                if word.mark in ("", " "):
                    word_data.append(lang_model.PAD_token)
                else:
                    word_data.append(lang_model.get_word_index(word.mark))
                found_word = True
                break
                
        if not found_word:
            word_data.append(lang_model.UNK_token)
            
    return word_data
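

# Minimal usage sketch (illustrative only, not part of the original pipeline):
# the values below are placeholders. `args` only needs the attributes the
# functions above read (t_pre_encoder, pose_fps, data_path_1), `selected_file`
# stands in for the project's file-index DataFrame, and the paths are hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    import pandas as pd

    args = SimpleNamespace(t_pre_encoder="bert", pose_fps=30, data_path_1="./")
    selected_file = pd.DataFrame({"id": ["speaker1_clip01"]})
    data = {"pose": np.zeros((300, 165))}  # placeholder pose sequence (frames x dims)

    process_word_data(
        data_dir="./cache/",
        word_file="./textgrids/speaker1_clip01.TextGrid",
        args=args,
        data=data,
        f_name="speaker1_clip01",
        selected_file=selected_file,
        lang_model=None,  # only required for the non-BERT branch
    )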