lynn-twinkl
committed on
Commit
·
2204e3e
1
Parent(s):
d2199ff
Functions to begin training, generate predictions, and remove non-context labels
Browse files
ner-training/begin-training.zsh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/zsh
#
# Train a spaCy pipeline from transformer.cfg on GPU 0.
#
# Usage: begin-training.zsh <train_spacy_file> <dev_spacy_file> <model_outdir>
#   train_spacy_file  value passed to spaCy as --paths.train
#   dev_spacy_file    value passed to spaCy as --paths.dev
#   model_outdir      value passed to spaCy as --output

# Fail early with a usage hint instead of handing empty paths to spaCy.
if (( $# < 3 )); then
  print -u2 "Usage: $0 <train_spacy_file> <dev_spacy_file> <model_outdir>"
  exit 1
fi

# Stop if the virtualenv cannot be activated; otherwise training would
# silently run with whatever python3 happens to be on PATH.
source ner_venv/bin/activate || exit 1

train_spacy_file=$1
dev_spacy_file=$2
model_outdir=$3

python3 -m spacy train transformer.cfg \
  --paths.train "$train_spacy_file" \
  --paths.dev "$dev_spacy_file" \
  --gpu-id 0 \
  --output "$model_outdir"
|
| 14 |
+
|
ner-training/predict.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spacy
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
# Print the entities a trained spaCy pipeline finds in one CSV column.
#
# Usage: python predict.py <csv_path> <custom_model_path>
if len(sys.argv) < 3:
    sys.exit(f"Usage: {sys.argv[0]} <csv_path> <custom_model_path>")

csv_path = sys.argv[1]
custom_model_path = sys.argv[2]

df = pd.read_csv(csv_path)
# Empty CSV cells are read as NaN (a float), which spaCy refuses to process.
# Drop them first, then coerce the remaining values to str.
texts = df['Additional Info'].dropna().astype(str).to_list()

trained_nlp = spacy.load(custom_model_path)

# nlp.pipe processes the texts in batches and yields one Doc per input text,
# in order, so each Doc can be zipped back with its source string.
for text, doc in zip(texts, trained_nlp.pipe(texts)):
    print(f"TEXT: {text}")
    print()
    print("ENTITIES:", [(ent.text, ent.label_) for ent in doc.ents])
    print('-'*60)
|
ner-training/remove_non_context_labels.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
# The JSON file to filter is given as the first command-line argument.
if len(sys.argv) < 2:
    sys.exit(f"Usage: {sys.argv[0]} <file_to_filter.json>")

file_to_filter = sys.argv[1]

# Read as UTF-8 explicitly: the platform default encoding may differ and
# would fail on non-ASCII text in the file.
with open(file_to_filter, 'r', encoding='utf-8') as input_file:
    dataset = json.load(input_file)
|
| 8 |
+
|
| 9 |
+
def filter_context_labels(dataset):
    """Keep only the annotations tagged ``'Context'`` on every item.

    Each item's ``'label'`` list is replaced in place with the entries whose
    ``'labels'`` value contains ``'Context'``. An item with no ``'label'``
    key (or with it set to ``None``) gets an empty list instead of raising
    ``KeyError``.

    Args:
        dataset: Iterable of dict items; each may carry a ``'label'`` list
            of dicts that have a ``'labels'`` entry.

    Returns:
        The same ``dataset`` object, mutated in place.
    """
    for item in dataset:
        # Tolerate items that have no annotations at all.
        annotations = item.get('label') or []
        item['label'] = [
            annotation
            for annotation in annotations
            if 'Context' in annotation['labels']
        ]
    return dataset
|
| 13 |
+
|
| 14 |
+
# Strip every non-Context annotation, then store the result as
# two-space-indented JSON in the current working directory.
filtered_data = filter_context_labels(dataset)

with open('context-only-labels.json', 'w') as sink:
    sink.write(json.dumps(filtered_data, indent=2))
|