Spaces:
Sleeping
Sleeping
| import _pickle as cPickle | |
| from utils import plain_to_conll, conll_to_output, ssf_to_conll, conll_to_ssf | |
| import os | |
def word_features(sent, i):
    """Return the CRF feature dict for token ``i`` of ``sent``.

    Args:
        sent: Sequence of tuples whose first element is the token text
            (``load_and_predict`` in this file passes ``(token, '')`` pairs).
        i: Index of the token to describe.

    Returns:
        dict with keys ``word``, ``prevword``, ``nextword``, ``suff_1`` ..
        ``suff_4``, ``pref_1`` .. ``pref_4`` and ``prev2word``. Neighbours
        that fall before the sentence start are ``'<START>'``; the neighbour
        after the last token is ``'<END>'``.
    """
    current = sent[i][0]
    previous = '<START>' if i == 0 else sent[i - 1][0]
    before_previous = '<START>' if i <= 1 else sent[i - 2][0]
    following = '<END>' if i == len(sent) - 1 else sent[i + 1][0]

    features = {
        'word': current,
        'prevword': previous,
        'nextword': following,
    }
    # Character affixes of length 1..4; words shorter than the width
    # yield the whole word (standard slice semantics).
    for width in range(1, 5):
        features['suff_' + str(width)] = current[-width:]
    for width in range(1, 5):
        features['pref_' + str(width)] = current[:width]
    features['prev2word'] = before_previous
    return features
def sent2features(sent):
    """Return a list with one ``word_features`` dict per token of ``sent``, in order."""
    token_count = len(sent)
    per_token = []
    for position in range(token_count):
        per_token.append(word_features(sent, position))
    return per_token
def load_and_predict(input_file, model, output_file):
    """Tag a CoNLL-style file with a pickled model and write token/label pairs.

    Args:
        input_file: Path to a UTF-8 text file with one token per line (only
            the first tab-separated column is used) and blank lines between
            sentences. A trailing blank line after the last sentence is
            optional.
        model: Path to a pickled object whose ``predict`` accepts a list of
            per-sentence feature lists (presumably an ``sklearn_crfsuite.CRF``
            -- confirm against the training code).
        output_file: Path of the UTF-8 file to write. Each line is
            ``token<TAB>label`` and every sentence is followed by a blank line.

    Returns:
        ``output_file``.
    """
    # SECURITY: unpickling can execute arbitrary code; only load model files
    # from a trusted source.
    with open(model, 'rb') as fid:
        crf = cPickle.load(fid)

    test_data = []
    with open(input_file, encoding="utf8") as fr:
        temp = []
        for line in fr:
            line = line.strip()
            if line:
                # The label slot is empty; it is only filled in by predict().
                temp.append((line.split("\t")[0], ''))
            elif temp:
                test_data.append(temp)
                temp = []
        # Flush the final sentence when the file does not end with a blank
        # line; previously it was silently dropped.
        if temp:
            test_data.append(temp)

    X_test = [sent2features(s) for s in test_data]
    y_pred = crf.predict(X_test)
    with open(output_file, 'w', encoding="utf-8") as f:
        # One label sequence per sentence, one label per token.
        for sentence, labels in zip(test_data, y_pred):
            for (token, _), label in zip(sentence, labels):
                f.write(token + "\t" + label + "\n")
            f.write("\n")
    return output_file
def predict(input_file, model, file_type, output_file="output.txt"):
    """Tag ``input_file`` with ``model`` and write the result to ``output_file``.

    Args:
        input_file: Path to the document to tag.
        model: Path to the pickled model, forwarded to ``load_and_predict``.
        file_type: ``"plain"`` converts with ``plain_to_conll`` and
            ``conll_to_output``. ``"ssf"`` converts with ``ssf_to_conll`` and
            ``conll_to_ssf`` (which is also given the original
            ``input_file``). Any other value treats ``input_file`` as already
            being token-per-line CoNLL and writes the tagged CoNLL directly.
        output_file: Destination path; defaults to ``"output.txt"``.

    Returns:
        ``output_file``.
    """
    if file_type not in ("plain", "ssf"):
        # Already CoNLL. load_and_predict reads all of its input before it
        # opens its output, so tagging straight into output_file is safe and
        # avoids a rename of an intermediate file.
        load_and_predict(input_file, model, output_file)
        return output_file

    import tempfile

    # Intermediate files go in a private temporary directory. This stops
    # concurrent calls from overwriting each other's fixed-name files in the
    # working directory, and nothing is left behind afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        temp_conll = os.path.join(workdir, "temp_input.conll")
        tagged_conll = os.path.join(workdir, "tagged_output.conll")
        if file_type == "plain":
            plain_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            conll_to_output(tagged_conll, output_file)
        else:
            ssf_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            conll_to_ssf(tagged_conll, input_file, output_file)
    return output_file