File size: 2,377 Bytes
ec86c24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import _pickle as cPickle
import os
import tempfile

from utils import plain_to_conll, conll_to_output, ssf_to_conll, conll_to_ssf

def word_features(sent, i):
    """Build the CRF feature dict for the token at position ``i`` of ``sent``.

    ``sent`` is a sequence of items whose element ``[0]`` is the token text.
    Features are the token itself, its neighbours (one and two to the left,
    one to the right) and its 1- to 4-character prefixes and suffixes.
    Neighbours outside the sentence are replaced by the sentinels
    ``'<START>'`` (left side) and ``'<END>'`` (right side).
    """
    token = sent[i][0]
    last_index = len(sent) - 1

    before = '<START>' if i == 0 else sent[i - 1][0]
    two_before = sent[i - 2][0] if i > 1 else '<START>'
    after = '<END>' if i == last_index else sent[i + 1][0]

    # Key insertion order mirrors the historical feature layout.
    feats = {'word': token, 'prevword': before, 'nextword': after}
    feats.update({'suff_%d' % n: token[-n:] for n in range(1, 5)})
    feats.update({'pref_%d' % n: token[:n] for n in range(1, 5)})
    feats['prev2word'] = two_before
    return feats

def sent2features(sent):
    """Turn a sentence into the list of per-token feature dicts fed to the CRF.

    Element ``k`` of the result is ``word_features(sent, k)``.
    """
    positions = range(len(sent))
    return [word_features(sent, pos) for pos in positions]

def _read_sentences(input_file):
    """Read a CoNLL-style file into a list of sentences of ``(token, '')`` tuples.

    Each non-blank line contributes one tuple whose token is the first
    tab-separated column of the stripped line. Blank (or whitespace-only)
    lines separate sentences; consecutive blank lines never yield an empty
    sentence.
    """
    sentences = []
    current = []
    with open(input_file, encoding="utf8") as fr:
        for line in fr:
            line = line.strip()
            if line:
                current.append((line.split("\t")[0], ''))
            elif current:
                sentences.append(current)
                current = []
    # The last sentence need not be followed by a blank line; without this
    # flush it would be silently dropped from the tagged output.
    if current:
        sentences.append(current)
    return sentences


def load_and_predict(input_file, model, output_file):
    """Tag every token of a CoNLL-style file with a pickled CRF model.

    Args:
        input_file: path to a UTF-8 file with one token per line (only the
            first tab-separated column is used) and blank lines between
            sentences.
        model: path to a pickled model object whose ``predict`` accepts a list
            of per-sentence feature lists (as built by ``sent2features``) and
            returns one tag sequence per sentence.
        output_file: path to write; each line is ``token<TAB>tag`` and every
            sentence is followed by a blank line.

    Returns:
        ``output_file``.
    """
    # SECURITY: unpickling can execute arbitrary code; only load model files
    # that come from a trusted source.
    with open(model, 'rb') as fid:
        crf = cPickle.load(fid)

    sentences = _read_sentences(input_file)
    features = [sent2features(s) for s in sentences]
    predictions = crf.predict(features)

    with open(output_file, 'w', encoding="utf-8") as f:
        for sentence, tags in zip(sentences, predictions):
            for (token, _), tag in zip(sentence, tags):
                f.write(token + "\t" + tag + "\n")
            f.write("\n")
    return output_file


def predict(input_file, model, file_type, output_file="output.txt"):
    """Tag ``input_file`` with the CRF model at ``model`` and write ``output_file``.

    Args:
        input_file: path of the text to tag.
        model: path of the pickled CRF model, passed to ``load_and_predict``.
        file_type: ``"plain"`` converts the input with ``plain_to_conll`` and
            the tagged result back with ``conll_to_output``; ``"ssf"`` uses
            ``ssf_to_conll`` / ``conll_to_ssf`` (the latter also receives the
            original input); any other value treats ``input_file`` as already
            being in the CoNLL layout that ``load_and_predict`` reads.
        output_file: destination path (default ``"output.txt"``).

    Returns:
        ``output_file``.
    """
    if file_type not in ("plain", "ssf"):
        # Input is already CoNLL: tag it straight into the destination.
        # load_and_predict reads its whole input before opening the output,
        # so this is safe even when input_file == output_file.
        load_and_predict(input_file, model, output_file)
        return output_file

    # Intermediates live in a private temporary directory so that concurrent
    # runs don't collide, same-named files in the working directory are never
    # clobbered, and nothing is left behind afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        temp_conll = os.path.join(workdir, "temp_input.conll")
        tagged_conll = os.path.join(workdir, "tagged_output.conll")
        if file_type == "plain":
            plain_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            conll_to_output(tagged_conll, output_file)
        else:
            ssf_to_conll(input_file, temp_conll)
            load_and_predict(temp_conll, model, tagged_conll)
            conll_to_ssf(tagged_conll, input_file, output_file)

    return output_file