Toadoum commited on
Commit
88f74f5
·
verified ·
1 Parent(s): c447b1f

Create nlu.py

Browse files
Files changed (1) hide show
  1. nlu.py +48 -0
nlu.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from copy import deepcopy
4
+ from tqdm import tqdm
5
+
6
+ def load_json(path):
7
+ return json.load(open(path, 'r'))
8
+
9
+ def save_json(data, path):
10
+ return json.dump(data, open(path, 'w'))
11
+
12
+ def apply_patch_on_split(split_data, split_patch):
13
+ patched_split = []
14
+ for instance in tqdm(split_data):
15
+ instance_id = instance['id']
16
+ if instance_id in split_patch.keys():
17
+ new_instance = deepcopy(instance)
18
+ new_instance['relation'] = split_patch[instance_id]
19
+ patched_split.append(new_instance)
20
+ return patched_split
21
+
22
+ def apply_patches(split_dir, patch_dir, split_names=['train', 'dev', 'test']):
23
+ patches = {name: None for name in split_names}
24
+ for name in split_names:
25
+ # Set paths
26
+ split_path = os.path.join(split_dir, name + '.json')
27
+ patch_path = os.path.join(patch_dir, name + '_id2label.json')
28
+ # Load data
29
+ split_data = load_json(split_path)
30
+ patch_id2label = load_json(patch_path)
31
+ patch = apply_patch_on_split(split_data, patch_id2label)
32
+ patches[name] = patch
33
+ return patches
34
+
35
+ def save_patches(patch2data, save_dir):
36
+ for patch_name, data in patch2data.items():
37
+ save_path = os.path.join(save_dir, patch_name + '.json')
38
+ save_json(data, save_path)
39
+
40
+ if __name__ == '__main__':
41
+ tacred_dir = None # Directory where TACRED is stored
42
+ patch_dir = os.getcwd() # Directory where patches are located
43
+ save_dir = None # Directory where patched data should be saved
44
+ os.makedirs(save_dir, exist_ok=True)
45
+ # Apply patches on data splits
46
+ patch2data = apply_patches(split_dir=tacred_dir, patch_dir=patch_dir)
47
+ # Save patched data to desired directory
48
+ save_patches(patch2data=patch2data, save_dir=save_dir)