lynn-twinkl committed on
Commit
2204e3e
·
1 Parent(s): d2199ff

Functions to begin training, generate predictions, and remove non-context labels

Browse files
ner-training/begin-training.zsh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/zsh
#
# Train a spaCy pipeline from transformer.cfg on GPU 0.
#
# Usage: begin-training.zsh <train_spacy_file> <dev_spacy_file> <model_outdir>
#
#   train_spacy_file  passed to spaCy as --paths.train
#   dev_spacy_file    passed to spaCy as --paths.dev
#   model_outdir      passed to spaCy as --output
#
# ner_venv/ and transformer.cfg are referenced by relative path, so run this
# script from the directory that contains both.

# Refuse to start with missing arguments: otherwise empty strings would be
# handed to `spacy train` as paths.
if (( $# != 3 )); then
  echo "Usage: $0 <train_spacy_file> <dev_spacy_file> <model_outdir>" >&2
  exit 2
fi

train_spacy_file=$1
dev_spacy_file=$2
model_outdir=$3

# Stop if the virtualenv cannot be activated rather than training with
# whichever python3 happens to be first on PATH.
source ner_venv/bin/activate || exit 1

python3 -m spacy train transformer.cfg \
  --paths.train "$train_spacy_file" \
  --paths.dev "$dev_spacy_file" \
  --gpu-id 0 \
  --output "$model_outdir"
14
+
ner-training/predict.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Print the entities a trained spaCy pipeline finds in a CSV text column.

Usage: python predict.py <csv_path> <custom_model_path>
"""
import sys

import pandas as pd


def load_texts(csv_path, column='Additional Info'):
    """Return the non-empty values of *column* in the CSV as a list of str.

    Args:
        csv_path: Path (or file-like object) accepted by ``pandas.read_csv``.
        column: Name of the column holding the texts.

    Returns:
        The column's values as strings, in file order, with missing cells
        removed.

    Raises:
        KeyError: If *column* is not present in the CSV.
    """
    df = pd.read_csv(csv_path)
    # Empty cells are read as float NaN; the pipeline only accepts strings,
    # so drop them and coerce whatever remains (e.g. numeric cells) to str.
    return df[column].dropna().astype(str).to_list()


def main(argv=None):
    """Run the prediction script.

    Args:
        argv: Command-line arguments without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 2:
        print("Usage: python predict.py <csv_path> <custom_model_path>",
              file=sys.stderr)
        return 2
    csv_path, custom_model_path = args

    # Imported here so a usage error is reported without first paying for
    # the spaCy import.
    import spacy

    texts = load_texts(csv_path)
    trained_nlp = spacy.load(custom_model_path)

    # pipe() processes the texts in batches and yields the docs in input
    # order, so each doc can be paired with its source text.
    for text, doc in zip(texts, trained_nlp.pipe(texts)):
        print(f"TEXT: {text}")
        print()
        print("ENTITIES:", [(ent.text, ent.label_) for ent in doc.ents])
        print('-'*60)
    return 0


if __name__ == "__main__":
    sys.exit(main())
ner-training/remove_non_context_labels.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Keep only the 'Context' annotations in a JSON annotation export.

Usage: python remove_non_context_labels.py <input.json> [output.json]

The output path defaults to 'context-only-labels.json'.
"""
import json
import sys

DEFAULT_OUTPUT_PATH = 'context-only-labels.json'


def filter_context_labels(dataset):
    """Drop every annotation that is not tagged 'Context', in place.

    Args:
        dataset: Iterable of dicts. Each dict may carry a 'label' list whose
            entries are dicts with a 'labels' collection.

    Returns:
        The same *dataset* object, with each item's 'label' list reduced to
        the entries whose 'labels' contains 'Context'.
    """
    for item in dataset:
        # NOTE(review): items without a 'label' key (presumably unannotated
        # records) are left untouched instead of raising KeyError.
        if 'label' not in item:
            continue
        item['label'] = [
            span for span in item['label']
            # A span without 'labels' cannot be a Context span, so drop it.
            if 'Context' in (span.get('labels') or ())
        ]
    return dataset


def main(argv=None):
    """Filter the input file and write the result as indented JSON.

    Args:
        argv: Command-line arguments without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not 1 <= len(args) <= 2:
        print("Usage: python remove_non_context_labels.py "
              "<input.json> [output.json]", file=sys.stderr)
        return 2
    file_to_filter = args[0]
    output_path = args[1] if len(args) == 2 else DEFAULT_OUTPUT_PATH

    # Explicit encoding so the read does not depend on the platform default.
    with open(file_to_filter, 'r', encoding='utf-8') as input_file:
        dataset = json.load(input_file)

    filtered_data = filter_context_labels(dataset)

    with open(output_path, 'w', encoding='utf-8') as output_file:
        json.dump(filtered_data, output_file, indent=2)
    return 0


if __name__ == "__main__":
    sys.exit(main())