Other
English
minecraft
action prediction
File size: 7,707 Bytes
30b495e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""
A dialogue is a list of samples, where each sample contains one new speaker turn.

Takes a JSON file of annotated Minecraft games and converts it to
a turn format to be used in LLAMA parsing.

NB: when creating jsonl, use '###PS' for 'predict structure'
"""

import os
import json
import jsonlines
from collections import defaultdict


def preprocess_edus(tlist):
    """
    Label every edu of a dialogue and group the labelled edus by turn.

    tlist is a list of turn dicts, each read through its 'speaker' and
    'edus' keys. Every edu is rewritten as '<index> <Spkr> <text>', where
    the index is a counter running over the whole dialogue (starting at 0)
    and the speaker tag is the first four characters of the speaker name.

    Returns a list of lists, one inner list per turn. Ex:

    [...['6 <Buil> What is D2'],
    ['7 <Arch> Ah there is no stack,', '8 <Arch> pick up the washer'],...]

    Here one turn holds edu 6 and the following turn holds edus 7 and 8.

    NB: in a dialogue, might be best to change speakers to "Arch" and "Buil" to
    reflect MSDC training data
    """
    labelled_turns = []
    next_index = 0
    for turn in tlist:
        # if needed, write code to change speaker names here
        tag = '<' + turn['speaker'][:4] + '>'
        labelled = [
            str(next_index + offset) + ' ' + tag + ' ' + text
            for offset, text in enumerate(turn['edus'])
        ]
        # the counter keeps running across turns
        next_index += len(labelled)
        labelled_turns.append(labelled)

    return labelled_turns

def get_windows(dial_turns, distance = 15):
    """
    Takes the output from preprocess_edus() and
    returns a list of index pairs. Each pair gives the delimiting (inclusive)
    turn indexes for a window of turns whose total edus <= distance.

    Ex.
    [(0, 11), (1, 12), (4, 13), (5, 14), ...]

    Here, turns 0 through 11 contain edus <= distance, but once the edus from
    turn 12 are added, the window has to be adjusted in order for edus to
    remain <= distance. The window must be shifted to 1-12, then to 4-13, etc.

    The first pair always starts at turn 0 and ends at the last turn that
    still fits. If the whole dialogue fits within `distance`, the single
    window (0, last_turn) is returned, so no turn is dropped. An empty
    dialogue yields [(0, 0)].

    NB: a single turn holding more than `distance` edus produces a pair
    whose start is greater than its end, e.g. (i+1, i); callers should be
    aware of this.
    """
    edu_lens = [len(d) for d in dial_turns]
    windows = []

    # By default the first window covers the whole dialogue; it is only
    # shortened when the running edu count actually exceeds `distance`.
    # (Previously this defaulted to 0, which silently discarded every turn
    # after the first in dialogues shorter than `distance`.)
    first_cutoff = max(len(edu_lens) - 1, 0)
    esum = 0
    for i, w in enumerate(edu_lens):
        esum += w
        if esum > distance:
            first_cutoff = i - 1
            break
    windows.append((0, first_cutoff))

    # For every turn beyond the first window, walk backwards from that turn
    # until the edu budget is exceeded; the window starts just after the
    # turn that broke the budget.
    for i in range(first_cutoff + 1, len(edu_lens)):
        esum = 0
        for r in range(i, -1, -1):
            esum += edu_lens[r]
            if esum > distance:
                windows.append((r + 1, i))
                break

    return windows

def format_rels(index_list, rel_dict):
    """
    Collect, per dialogue turn, the forward relations that point into it.

    Takes as input:
    1. index_list: a list of lists, one per dialogue turn, holding the edu
    indexes of the edus in that turn.
    2. rel_dict: the relations of the dialogue, an iterable of dicts read
    through their 'x', 'y' and 'type' keys.

    Returns a list of lists aligned with index_list. Each inner list holds
    the relations whose target (y index) is an edu of that turn, ordered by
    y index (ties keep their order in rel_dict).

    Each relation is in the format [x index, y index, 'REL(x,y)'], where REL
    is the abbreviation of the relation type.

    *NB backwards relations (x >= y) in the data are ignored. A relation
    type without an abbreviation raises KeyError.
    """
    map_rels_str = {'Comment':'COM', 'Contrast':'CONTR', 'Correction':'CORR', 'Question-answer_pair':'QAP', 
                    'Acknowledgement':'ACK','Elaboration':'ELAB','Clarification_question':'CLARIFQ', 
                    'Conditional':'COND', 'Continuation':'CONTIN', 'Result':'RES', 'Explanation':'EXPL', 
                    'Q-Elab':'QELAB', 'Alternation':'ALT', 'Narration':'NARR', 
                    'Confirmation_question':'CONFQ', 'Sequence':'SEQ'}

    rel_list = []
    for turn_edus in index_list:
        turn_rels = []
        for rel in rel_dict:
            # keep only relations targeting an edu of this turn
            if rel['y'] not in turn_edus:
                continue
            source, target = rel['x'], rel['y']
            # only take forward relations
            if not (source < target):
                continue
            label = map_rels_str[rel['type']] + '(' + str(source) + ',' + str(target) + ')'
            turn_rels.append([source, target, label])
        # stable sort on the target index
        turn_rels.sort(key=lambda item: item[1])
        rel_list.append(turn_rels)

    return rel_list


# ---------------------------------------------------------------------------
# Script body: read the turn-segmented dialogues and their relation
# annotations, build one sample per new speaker turn (context + structure so
# far + new turn -> relations to predict), and write all samples to a jsonl.
# ---------------------------------------------------------------------------
current_folder=os.getcwd()

# Replace the <...> placeholders with the real file names.
# os.path.join supplies the path separator; plain string concatenation
# produced paths such as '/work/dir<turns>.json' for the two input files.
data_turns_path = os.path.join(current_folder, '<turns>.json')
annotation_path = os.path.join(current_folder, '<orig_data>.json')
save_path = os.path.join(current_folder, '<parser>.jsonl')

with open(data_turns_path, 'r') as j:
    dialogues = json.load(j)

with open(annotation_path, 'r') as j:
    annotations = json.load(j)

json_l = []  # every sample (dict with 'PS' and 'sample') in output order

dialogue_count = 0

DISTANCE = 15  # maximum number of edus kept in the context window
start_index = 0


for dial in dialogues:
    dialogue_count += 1
    dial_id = dial['id']
    print(dial_id)

    #if generating a test file for incremental parsing, add space marker between dialogues
    #for any other files (test for gold parsing or train), remove this ---->
    sample = {}
    sample['PS'] = ""
    sample['sample'] = "NEW DIALOGUE " + dial_id
    json_l.append(sample)
    #<-------------------------------

    turns = preprocess_edus(dial['turns']) #preprocess game edus

    windows = get_windows(turns, DISTANCE)

    # relations of the annotation entry with the same id
    # (raises IndexError if the dialogue has no annotation entry)
    dial_rels = [a for a in annotations if a['id'] == dial_id][0]['relations']

    # recover the integer edu index that prefixes each labelled edu string
    turn_indexes = [[int(e.split('<')[0].strip()) for e in turn] for turn in turns]

    relations = format_rels(turn_indexes, dial_rels)

    #now add the relations for each turn to the original turns list
    #the turns_plus_relations data structure is what we will use to create the data
    #each entry is [labelled edus of the turn, relations targeting the turn]
    turns_plus_relations = [[t, relations[i]] for i, t in enumerate(turns)]

    #start with first window
    global_context = []
    structure = []
    global_context.extend(turns_plus_relations[0][0]) #add 0 turn "mission has started"
    for t in turns_plus_relations[1:windows[0][1]+1]: #go through each subsequent turn in first window and create a new sample
        sample = {}
        c = "\n".join(global_context)
        n = "\n".join(t[0])

        #formatted relations whose target is an edu of this new turn;
        #no cutoff is needed because the window still starts at turn 0
        rels_list = [r[2] for r in t[1]]
        r = ' '.join(rels_list)
        s = ' '.join(structure)

        sample['PS'] = r
        sample['sample'] = 'Context: ' + c + '\nStructure: ' + s + '\nNew Turn: ' + n
        json_l.append(sample)

        #the new turn and its relations become context for the next sample
        global_context.extend(t[0])
        structure.extend(rels_list)

    #now for each new turn added beyond the first window, we need to adjust the context window
    for window in windows[1:]:

        #find min edu index for this window (first turn of the window)
        min_x = min(turn_indexes[window[0]])

        #rebuild context and structure from the turns before the new turn
        global_context = []
        structure = []
        for tw in turns_plus_relations[window[0]:window[1]]:
            global_context.extend(tw[0])
            #keep only relations whose source (x index) is still inside the
            #window, i.e. greater than or equal to the cutoff
            structure.extend([rel[2] for rel in tw[1] if rel[0] >= min_x])

        sample = {}
        c = "\n".join(global_context)
        n = "\n".join(turns_plus_relations[window[1]][0])

        #relations to predict for the new turn, restricted to sources inside the window
        rels_list = [r[2] for r in turns_plus_relations[window[1]][1] if r[0] >= min_x]
        r = ' '.join(rels_list)
        s = ' '.join(structure)

        sample['PS'] = r
        sample['sample'] = 'Context: ' + c + '\nStructure: ' + s + '\nNew Turn: ' + n
        json_l.append(sample)


#write one json object per line
with jsonlines.open(save_path, mode='w') as writer:
    for x in json_l:
        writer.write(x)

print('jsonl saved for {} games'.format(dialogue_count))