Other
English
minecraft
action prediction
Llamipa / bespoke /format_annotated_jsonl.py
Kqte's picture
Upload 4 files
30b495e verified
"""
A dialogue is represented as a list of samples, where each sample contains one new speaker turn.
This script takes a JSON file of annotated Minecraft games and converts it to
a turn format to be used in LLaMA parsing.
NB: when creating the jsonl, use '###PS' for 'predict structure'.
"""
import os
import json
import jsonlines
from collections import defaultdict
def preprocess_edus(tlist):
"""
returns a list of lists, where each list contains the edus for a single turn.
Ex:
[...['6 <Buil> What is D2'],
['7 <Arch> Ah there is no stack,', '8 <Arch> pick up the washer'],...]
we see one turn contains the edu index 6, and the next turn contains the edus
with indexes 7 and 8.
NB: in a dialogue, might be best to change speakers to "Arch" and "Buil" to
reflect MSDC training data
"""
elist = []
cnt = 0
for turn in tlist:
speaker = turn['speaker'][:4]
#if needed, write code to change speaker names here
new_edus = []
for edu in turn['edus']:
new_string = str(cnt)+' '+'<'+speaker+'>'+' ' + edu
new_edus.append(new_string)
cnt += 1
elist.append(new_edus)
return elist
def get_windows(dial_turns, distance=15):
    """
    Split a dialogue into sliding windows of consecutive turns.

    Takes the output of preprocess_edus() (one list of EDUs per turn) and
    returns a list of (start, end) turn-index pairs, both inclusive. Each
    window ends on a "new" turn and reaches back as far as possible while
    its total number of EDUs stays <= distance.
    Ex.
    [(0, 11), (1, 12), (4, 13), (5, 14), ...]
    Here, turns 0 through 11 contain <= distance EDUs. Once the EDUs from
    turn 12 are added, the window has to start at turn 1 to stay within the
    limit. Adding turn 13 pushes the start to turn 4, and so on.

    Edge cases:
    - An empty dialogue yields [].
    - If the whole dialogue fits within distance, the single window
      (0, len(dial_turns) - 1) is returned.
    - A turn that on its own has more than distance EDUs gets the window
      (i, i), so every window contains at least its new turn. Turn 0 is
      always the first window's end at minimum: (0, 0).

    :param dial_turns: list of turns, each a list of EDUs.
    :param distance: maximum number of EDUs allowed in one window.
    :return: list of (start, end) tuples.
    """
    edu_lens = [len(turn) for turn in dial_turns]
    if not edu_lens:
        return []

    # The first window always starts at turn 0. If the running total never
    # exceeds `distance`, the whole dialogue fits in it. Without this
    # default, short dialogues would get (0, 0) and produce no samples.
    first_cutoff = len(edu_lens) - 1
    esum = 0
    for i, n_edus in enumerate(edu_lens):
        esum += n_edus
        if esum > distance:
            # The clamp keeps turn 0 in the window even if it alone is too big.
            first_cutoff = max(i - 1, 0)
            break
    windows = [(0, first_cutoff)]

    # Every later turn i is the new turn of its own window. Walk backwards
    # to find the earliest start that keeps the EDU total <= distance.
    for i in range(first_cutoff + 1, len(edu_lens)):
        esum = 0
        start = 0
        for r in range(i, -1, -1):
            esum += edu_lens[r]
            if esum > distance:
                # min() turns an oversized turn i into (i, i) instead of the
                # inverted, invalid window (i + 1, i).
                start = min(r + 1, i)
                break
        windows.append((start, i))
    return windows
def format_rels(index_list, rel_dict):
    """
    Group a dialogue's forward discourse relations by the turn they target.

    :param index_list: one list per dialogue turn holding the EDU indexes of
        that turn, e.g. [[0], [1], [2, 3]].
    :param rel_dict: the dialogue's relations. Each item is read as a dict
        with 'x' (source EDU index), 'y' (target EDU index) and 'type'
        (relation name; must be a key of the abbreviation table below,
        otherwise KeyError is raised for forward relations).
    :return: a list parallel to index_list. Entry k holds, sorted by target
        index, every relation whose 'y' is in turn k and whose 'x' < 'y'.
        Each relation is formatted as [x, y, 'ABBR(x,y)'].

    *NB backwards relations (x >= y) are ignored.
    """
    abbreviations = {
        'Comment': 'COM', 'Contrast': 'CONTR', 'Correction': 'CORR',
        'Question-answer_pair': 'QAP', 'Acknowledgement': 'ACK',
        'Elaboration': 'ELAB', 'Clarification_question': 'CLARIFQ',
        'Conditional': 'COND', 'Continuation': 'CONTIN', 'Result': 'RES',
        'Explanation': 'EXPL', 'Q-Elab': 'QELAB', 'Alternation': 'ALT',
        'Narration': 'NARR', 'Confirmation_question': 'CONFQ',
        'Sequence': 'SEQ',
    }
    grouped = []
    for turn_edus in index_list:
        # relations whose target EDU belongs to this turn
        targeting_turn = [rel for rel in rel_dict if rel['y'] in turn_edus]
        forward = [
            [
                rel['x'],
                rel['y'],
                '{}({!s},{!s})'.format(abbreviations[rel['type']], rel['x'], rel['y']),
            ]
            for rel in targeting_turn
            if rel['x'] < rel['y']
        ]
        forward.sort(key=lambda triple: triple[1])
        grouped.append(forward)
    return grouped
# Input/output locations. The <...> parts are placeholders to replace with
# real file names. os.path.join supplies the path separator; plain string
# concatenation omitted it for the two input files, so they would have been
# looked up next to the working directory rather than inside it.
current_folder = os.getcwd()
data_turns_path = os.path.join(current_folder, '<turns>.json')  # games split into turns/EDUs
annotation_path = os.path.join(current_folder, '<orig_data>.json')  # games with relation annotations
save_path = os.path.join(current_folder, '<parser>.jsonl')

# JSON text is UTF-8 (RFC 8259), so do not rely on the platform default encoding.
with open(data_turns_path, 'r', encoding='utf-8') as j:
    dialogues = json.load(j)
with open(annotation_path, 'r', encoding='utf-8') as j:
    annotations = json.load(j)
json_l = []  # all output samples, in order, across every dialogue
dialogue_count = 0
DISTANCE = 15  # maximum number of EDUs in one window (context turns + new turn)
# When True, a marker sample {'PS': '', 'sample': 'NEW DIALOGUE <id>'} is
# written before each dialogue. Keep it for a test file used in incremental
# parsing; set it to False for any other file (test for gold parsing, or train).
ADD_DIALOGUE_MARKERS = True


def _make_sample(context_edus, structure_rels, new_turn_edus, predicted_rels):
    """Return one output sample dict.

    'PS' (predict structure) is the space-separated list of relations to
    predict for the new turn. 'sample' is the prompt built from the context
    EDUs, the already-known structure and the new turn's EDUs.
    """
    return {
        'PS': ' '.join(predicted_rels),
        'sample': ('Context: ' + '\n'.join(context_edus)
                   + '\nStructure: ' + ' '.join(structure_rels)
                   + '\nNew Turn: ' + '\n'.join(new_turn_edus)),
    }


# Index annotations by dialogue id once, instead of rescanning the whole list
# for every dialogue. The first entry for an id wins, as with a first-match search.
annotations_by_id = {}
for annotation in annotations:
    annotations_by_id.setdefault(annotation['id'], annotation)

for dial in dialogues:
    dialogue_count += 1
    dial_id = dial['id']
    print(dial_id)
    if ADD_DIALOGUE_MARKERS:
        json_l.append({'PS': "", 'sample': "NEW DIALOGUE " + dial_id})

    turns = preprocess_edus(dial['turns'])  # preprocess game EDUs
    if not turns:
        # Nothing to parse; without this guard, turns_plus_relations[0] below raises IndexError.
        print('skipping dialogue {}: no turns'.format(dial_id))
        continue
    windows = get_windows(turns, DISTANCE)
    if dial_id not in annotations_by_id:
        raise KeyError('no annotation found for dialogue {}'.format(dial_id))
    dial_rels = annotations_by_id[dial_id]['relations']
    # integer EDU indexes of each turn, parsed from the "N <Spk> text" strings
    turn_indexes = [[int(e.split('<')[0].strip()) for e in turn] for turn in turns]
    relations = format_rels(turn_indexes, dial_rels)
    # Pair each turn's EDU strings with the relations that target it:
    # [[edus, rels], ...]. This is the structure the samples are built from.
    turns_plus_relations = [[t, rels] for t, rels in zip(turns, relations)]

    # First window (starts at turn 0): every turn after turn 0 ("mission has
    # started") becomes a new sample whose context is all previous turns.
    global_context = list(turns_plus_relations[0][0])
    structure = []
    for edus, rels in turns_plus_relations[1:windows[0][1] + 1]:
        rels_list = [r[2] for r in rels]
        json_l.append(_make_sample(global_context, structure, edus, rels_list))
        global_context.extend(edus)
        structure.extend(rels_list)

    # Each later window (start, end): context is turns start..end-1 and the
    # new turn is `end`. Relations whose source EDU has slid out of the
    # window (x < min_x) are dropped from both the structure and the target.
    for start, end in windows[1:]:
        min_x = min(turn_indexes[start])
        global_context = []
        structure = []
        for edus, rels in turns_plus_relations[start:end]:
            global_context.extend(edus)
            # keep only relations whose source x is >= the window's first EDU index
            structure.extend(rel[2] for rel in rels if rel[0] >= min_x)
        new_edus, new_rels = turns_plus_relations[end]
        # relations targeting the new turn, with their source inside the window
        rels_list = [r[2] for r in new_rels if r[0] >= min_x]
        json_l.append(_make_sample(global_context, structure, new_edus, rels_list))
# Serialise every sample dict in json_l as one JSON line of the output file.
with jsonlines.open(save_path, mode='w') as jsonl_writer:
    jsonl_writer.write_all(json_l)
print('jsonl saved for {} games'.format(dialogue_count))