Spaces:
Sleeping
Sleeping
File size: 2,377 Bytes
ec86c24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import _pickle as cPickle
from utils import plain_to_conll, conll_to_output, ssf_to_conll, conll_to_ssf
import os
def word_features(sent, i):
    """Build the CRF feature dict for token *i* of *sent*.

    *sent* is a sequence of (word, tag) pairs; only the word part is
    read here. Out-of-sentence context positions are filled with the
    '<START>' / '<END>' sentinel strings.
    """
    token = sent[i][0]
    prev_token = sent[i - 1][0] if i > 0 else '<START>'
    prev2_token = sent[i - 2][0] if i > 1 else '<START>'
    next_token = sent[i + 1][0] if i < len(sent) - 1 else '<END>'
    features = {
        'word': token,
        'prevword': prev_token,
        'nextword': next_token,
        'prev2word': prev2_token,
    }
    # Character affixes of length 1..4; slicing a short word just
    # yields the whole word (or '' for an empty token), same as before.
    for n in range(1, 5):
        features['suff_%d' % n] = token[-n:]
        features['pref_%d' % n] = token[:n]
    return features
def sent2features(sent):
    """Map each token position of *sent* to its CRF feature dict."""
    return [word_features(sent, idx) for idx, _ in enumerate(sent)]
def load_and_predict(input_file, model, output_file):
    """Tag a CoNLL-style file with a pickled CRF model.

    input_file: UTF-8 text, one token per line (first tab-separated
        column is the word), sentences separated by blank lines.
    model: path to a pickled model exposing .predict(X) -> one label
        sequence per sentence.
    output_file: destination; written as word<TAB>label lines with a
        blank line after each sentence.
    Returns output_file.
    """
    with open(model, 'rb') as fid:
        # NOTE(review): unpickling executes arbitrary code — only load
        # model files from a trusted source.
        crf = cPickle.load(fid)
    test_data = []
    with open(input_file, encoding="utf8") as fr:
        sentence = []
        for line in fr:
            line = line.strip()
            if line:
                # Keep only the word column; the tag slot stays empty.
                sentence.append((line.split("\t")[0], ''))
            elif sentence:
                test_data.append(sentence)
                sentence = []
        # Bug fix: flush the final sentence when the file does not end
        # with a trailing blank line (it was previously dropped).
        if sentence:
            test_data.append(sentence)
    X_test = [sent2features(s) for s in test_data]
    y_pred = crf.predict(X_test)
    with open(output_file, 'w', encoding="utf-8") as f:
        for sent, labels in zip(test_data, y_pred):
            for (word, _), label in zip(sent, labels):
                f.write(word + "\t" + label + "\n")
            f.write("\n")
    return output_file
def predict(input_file, model, file_type, output_file="output.txt"):
    """Tag *input_file* with the pickled *model* and write *output_file*.

    file_type selects the conversion pipeline:
      "plain" -> convert plain text to CoNLL, tag, convert back
      "ssf"   -> convert SSF to CoNLL, tag, convert back to SSF
      other   -> input is assumed to already be CoNLL; tag directly
    Returns output_file.
    """
    temp_conll = "temp_input.conll"
    tagged_conll = "tagged_output.conll"
    try:
        if file_type == "plain":
            plain_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            conll_to_output(tagged_conll, output_file)
        elif file_type == "ssf":
            ssf_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            conll_to_ssf(tagged_conll, input_file, output_file)
        else:
            load_and_predict(input_file, model, tagged_conll)
            # Move the tagged file into place (tagged_conll is consumed).
            os.replace(tagged_conll, output_file)
    finally:
        # Bug fix: remove intermediate files instead of leaving them
        # behind in the working directory after every call.
        for tmp in (temp_conll, tagged_conll):
            if os.path.exists(tmp):
                os.remove(tmp)
    return output_file
|