Indic-POS-tagger / predict.py
roymukund's picture
Upload 3 files
ec86c24 verified
raw
history blame
2.38 kB
import _pickle as cPickle
import os
import tempfile

from utils import plain_to_conll, conll_to_output, ssf_to_conll, conll_to_ssf
def word_features(sent, i):
    """Build the feature dict for the token at position ``i`` of ``sent``.

    Args:
        sent: Sequence of token records; only element ``[0]`` of each record
            (the surface form) is read.
        i: Index of the token to describe.

    Returns:
        A dict with the token itself, its two left neighbours, its right
        neighbour, and its prefixes/suffixes of length 1 to 4. Neighbours
        that fall outside the sentence are replaced by the ``'<START>'`` /
        ``'<END>'`` sentinels.
    """
    word = sent[i][0]

    # Context words, padded with sentinels at the sentence boundaries.
    prevword = '<START>' if i == 0 else sent[i - 1][0]
    prev2word = '<START>' if i <= 1 else sent[i - 2][0]
    nextword = '<END>' if i == len(sent) - 1 else sent[i + 1][0]

    # Slicing never raises on short words: affixes longer than the word
    # simply collapse to the whole word.
    return {
        'word': word,
        'prevword': prevword,
        'nextword': nextword,
        'suff_1': word[-1:],
        'suff_2': word[-2:],
        'suff_3': word[-3:],
        'suff_4': word[-4:],
        'pref_1': word[:1],
        'pref_2': word[:2],
        'pref_3': word[:3],
        'pref_4': word[:4],
        'prev2word': prev2word,
    }
def sent2features(sent):
    """Return one ``word_features`` dict per token of ``sent``, in order."""
    features = []
    for position in range(len(sent)):
        features.append(word_features(sent, position))
    return features
def load_and_predict(input_file, model, output_file):
    """Tag the tokens of a tab-separated, blank-line-delimited file.

    The input is read as UTF-8, one token per line; only the first
    tab-separated column of each line is used. Blank lines separate
    sentences. The output is written in the same layout with two columns:
    the token and its predicted tag.

    Args:
        input_file: Path of the token file to tag.
        model: Path of a pickled model object exposing ``predict``.
        output_file: Path the tagged result is written to (UTF-8).

    Returns:
        ``output_file``, unchanged.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load model files
    # that come from a trusted source.
    with open(model, 'rb') as fid:
        crf = cPickle.load(fid)

    test_data = []
    with open(input_file, encoding="utf8") as fr:
        temp = []
        for line in fr:
            line = line.strip()
            if line:
                # Keep only the token column; the tag slot is a placeholder.
                temp.append((line.split("\t")[0], ''))
            elif temp:
                test_data.append(temp)
                temp = []
        # Flush the last sentence: without this, a file that does not end
        # with a blank line would silently lose its final sentence.
        if temp:
            test_data.append(temp)

    X_test1 = [sent2features(s) for s in test_data]
    y_pred1 = crf.predict(X_test1)

    with open(output_file, 'w', encoding="utf-8") as f:
        for sentence, tags in zip(test_data, y_pred1):
            for (token, _), tag in zip(sentence, tags):
                f.write(token + "\t" + tag + "\n")
            f.write("\n")
    return output_file
def predict(input_file, model, file_type, output_file="output.txt"):
    """Tag ``input_file`` with ``model`` and write the result to ``output_file``.

    Args:
        input_file: Path of the file to tag.
        model: Path of the pickled model handed to ``load_and_predict``.
        file_type: ``"plain"`` and ``"ssf"`` inputs are first converted with
            ``plain_to_conll`` / ``ssf_to_conll`` and converted back after
            tagging with ``conll_to_output`` / ``conll_to_ssf``. Any other
            value makes ``input_file`` go to ``load_and_predict`` as is.
        output_file: Destination path of the tagged result.

    Returns:
        ``output_file``, unchanged.
    """
    # Intermediate files live in a private scratch directory so that they are
    # always cleaned up and concurrent calls cannot overwrite each other's
    # files. The directory is created next to the output file so the final
    # os.replace() stays on a single filesystem.
    out_dir = os.path.dirname(os.path.abspath(output_file))
    with tempfile.TemporaryDirectory(dir=out_dir) as scratch:
        temp_conll = os.path.join(scratch, "temp_input.conll")
        tagged_conll = os.path.join(scratch, "tagged_output.conll")

        if file_type == "plain":
            plain_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            conll_to_output(tagged_conll, output_file)
        elif file_type == "ssf":
            ssf_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            # The original SSF file is passed along with the tagged tokens.
            conll_to_ssf(tagged_conll, input_file, output_file)
        else:
            # Every other file_type is tagged directly, with no conversion.
            load_and_predict(input_file, model, tagged_conll)
            os.replace(tagged_conll, output_file)
    return output_file