|
|
""" |
|
|
A dialogue is a list of samples, where each sample contains one new speaker turn. |
|
|
|
|
|
Takes a JSON file of annotated Minecraft games and converts it into
a turn-based format to be used for parsing with LLaMA.
|
|
|
|
|
NB: when creating the JSONL prompts, use '###PS' to mark the 'predict structure' target (stored under the 'PS' key of each sample).
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import jsonlines |
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
def preprocess_edus(tlist):
    """
    Number every EDU in a dialogue and group the numbered EDUs by turn.

    Each EDU is rendered as "<index> <Spkr> <text>", where <index> is a running
    counter across the whole dialogue (starting at 0) and <Spkr> is the first
    four characters of the turn's 'speaker' value.

    Ex:

    [...['6 <Buil> What is D2'],
     ['7 <Arch> Ah there is no stack,', '8 <Arch> pick up the washer'],...]

    Here one turn holds the EDU with index 6, and the next turn holds the EDUs
    with indexes 7 and 8.

    NB: in a dialogue, it might be best to change speakers to "Arch" and "Buil"
    to reflect MSDC training data.
    """
    grouped = []
    running_index = 0
    for turn in tlist:
        speaker_tag = turn['speaker'][:4]
        # offset restarts at 0 per turn; running_index carries the global count.
        numbered = [
            str(running_index + offset) + ' <' + speaker_tag + '> ' + text
            for offset, text in enumerate(turn['edus'])
        ]
        running_index += len(numbered)
        grouped.append(numbered)
    return grouped
|
|
|
|
|
def get_windows(dial_turns, distance=15):
    """
    Takes the output from preprocess_edus() and returns a list of
    (start, end) turn-index pairs, both bounds inclusive. Each pair delimits
    a window of consecutive turns whose total number of edus is <= distance.

    Ex.
    [(0, 11), (1, 12), (4, 13), (5, 14), ...]

    Here, turns 0 through 11 contain <= distance edus, but once the edus from
    turn 12 are added, the window has to be adjusted so the edu count stays
    <= distance. The window must be shifted to 1-12, then to 4-13, etc.

    The first pair always starts at turn 0. If the whole dialogue fits within
    `distance` edus, the result is the single window (0, last_turn_index).
    An empty dialogue yields [(0, 0)].

    NOTE(review): if a single turn i alone has more than `distance` edus, the
    pair emitted for it is (i + 1, i), i.e. start > end (for i == 0 the first
    pair is (0, -1)). Callers should be prepared for that degenerate pair.
    """
    edu_lens = [len(d) for d in dial_turns]
    windows = []

    # First window: grow from turn 0 until adding a turn would exceed distance.
    esum = 0
    for i, w in enumerate(edu_lens):
        esum += w
        if esum > distance:
            first_cutoff = i - 1
            break
    else:
        # No prefix exceeds distance: the whole dialogue is one window.
        # Without this, first_cutoff would stay 0 and every later turn of a
        # short dialogue would be dropped.
        first_cutoff = max(len(edu_lens) - 1, 0)
    windows.append((0, first_cutoff))

    # Every later turn i closes a window; walk backwards from i and stop just
    # before the turn whose edus would push the total past distance. This
    # always breaks, because turns 0..first_cutoff+1 already exceed distance.
    for i in range(first_cutoff + 1, len(edu_lens)):
        esum = 0
        for r in range(i, -1, -1):
            esum += edu_lens[r]
            if esum > distance:
                windows.append((r + 1, i))
                break

    return windows
|
|
|
|
|
def format_rels(index_list, rel_dict):
    """
    Group a dialogue's forward relations by the turn containing their target.

    Takes as input:
    1. a list of lists, where each inner list corresponds to a dialogue turn
       and holds the edu indexes of the edus in that turn.
    2. the relations for the dialogue. Despite the parameter name, this is
       iterated as a sequence of dicts with keys 'x' (source edu index),
       'y' (target edu index) and 'type' (relation name).

    Returns a list with one entry per turn: the relations whose targets
    (y indexes) are among that turn's edus, ordered by y (ties keep the
    input order).

    Each relation is in the format [x index, y index, 'REL(x,y)'], where REL
    is the abbreviated relation name.

    *NB we ignore backwards relations in the data (and self-loops): only
    relations with x < y are kept. A relation type missing from the
    abbreviation table raises KeyError.
    """
    abbreviations = {
        'Comment': 'COM', 'Contrast': 'CONTR', 'Correction': 'CORR',
        'Question-answer_pair': 'QAP', 'Acknowledgement': 'ACK',
        'Elaboration': 'ELAB', 'Clarification_question': 'CLARIFQ',
        'Conditional': 'COND', 'Continuation': 'CONTIN', 'Result': 'RES',
        'Explanation': 'EXPL', 'Q-Elab': 'QELAB', 'Alternation': 'ALT',
        'Narration': 'NARR', 'Confirmation_question': 'CONFQ',
        'Sequence': 'SEQ',
    }

    grouped = []
    for turn_edus in index_list:
        forward = [
            [
                rel['x'],
                rel['y'],
                abbreviations[rel['type']] + '(' + str(rel['x']) + ',' + str(rel['y']) + ')',
            ]
            for rel in rel_dict
            if rel['y'] in turn_edus and rel['x'] < rel['y']
        ]
        # list.sort is stable, so relations sharing a target keep input order.
        forward.sort(key=lambda triple: triple[1])
        grouped.append(forward)
    return grouped
|
|
|
|
|
|
|
|
current_folder = os.getcwd()

# The angle-bracketed file names appear to be placeholders to fill in before
# running. os.path.join supplies the path separator between the folder and
# the file name.
data_turns_path = os.path.join(current_folder, '<turns>.json')
annotation_path = os.path.join(current_folder, '<orig_data>.json')
save_path = os.path.join(current_folder, '<parser>.jsonl')

# Turn-segmented dialogues; each entry is read for 'id' and 'turns' below.
with open(data_turns_path, 'r', encoding='utf-8') as j:
    dialogues = json.load(j)

# Annotated games; each entry is read for 'id' and 'relations' below.
with open(annotation_path, 'r', encoding='utf-8') as j:
    annotations = json.load(j)

# Index annotations by dialogue id once, rather than scanning the whole list
# for every dialogue. On a duplicated id the first annotation wins.
annotations_by_id = {}
for a in annotations:
    annotations_by_id.setdefault(a['id'], a)


def _make_sample(context_lines, structure_rels, new_turn_lines, target_rels):
    """Build one sample: 'PS' holds the relations to predict for the new turn,
    'sample' holds the prompt (context EDUs, known structure, new turn EDUs)."""
    return {
        'PS': ' '.join(target_rels),
        'sample': ('Context: ' + '\n'.join(context_lines)
                   + '\nStructure: ' + ' '.join(structure_rels)
                   + '\nNew Turn: ' + '\n'.join(new_turn_lines)),
    }


json_l = []
dialogue_count = 0
DISTANCE = 15

for dial in dialogues:
    dialogue_count += 1
    dial_id = dial['id']
    print(dial_id)

    # Marker sample separating consecutive dialogues in the output.
    json_l.append({'PS': "", 'sample': "NEW DIALOGUE " + dial_id})

    turns = preprocess_edus(dial['turns'])
    if not turns:
        # Nothing to predict; the windowing below needs at least one turn.
        continue

    windows = get_windows(turns, DISTANCE)

    annotation = annotations_by_id.get(dial_id)
    if annotation is None:
        raise KeyError('no annotation found for dialogue {!r}'.format(dial_id))
    dial_rels = annotation['relations']

    # EDU indexes of each turn, recovered from the "<idx> <Spkr> text" prefix
    # written by preprocess_edus.
    turn_indexes = [[int(e.split('<')[0].strip()) for e in turn] for turn in turns]

    relations = format_rels(turn_indexes, dial_rels)

    # Pair each turn's EDU strings with the relations that target it.
    turns_plus_relations = [[t, relations[i]] for i, t in enumerate(turns)]

    # First window: turn 0 seeds the context. Each following turn up to the
    # window end becomes a sample, and is then folded into the context and
    # structure for the next one.
    global_context = list(turns_plus_relations[0][0])
    structure = []
    for t in turns_plus_relations[1:windows[0][1] + 1]:
        rels_list = [r[2] for r in t[1]]
        json_l.append(_make_sample(global_context, structure, t[0], rels_list))
        global_context.extend(t[0])
        structure.extend(rels_list)

    # Sliding windows: the context is turns start..end-1 and the new turn is
    # `end`. Relations whose source EDU precedes the window are dropped.
    # NOTE(review): a degenerate window with start > end (a single turn with
    # more than DISTANCE edus, see get_windows) makes turn_indexes[start] point
    # past that turn and can raise IndexError on the last turn.
    for start, end in windows[1:]:
        min_x = min(turn_indexes[start])

        global_context = []
        structure = []
        for tw in turns_plus_relations[start:end]:
            global_context.extend(tw[0])
            structure.extend(rel[2] for rel in tw[1] if rel[0] >= min_x)

        new_turn, new_rels = turns_plus_relations[end]
        rels_list = [r[2] for r in new_rels if r[0] >= min_x]
        json_l.append(_make_sample(global_context, structure, new_turn, rels_list))

with jsonlines.open(save_path, mode='w') as writer:
    writer.write_all(json_l)

print('jsonl saved for {} games'.format(dialogue_count))
|
|
|
|
|
|