File size: 2,337 Bytes
685c2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from lingua.concept.linguagraph.standard import LinguaGraphNodePoses

from lingua.structure.gpgraph import GPGraph, GPGAuxNode, GPGPhraseNode, GPGNode

# Integer priority keyed by node-type name. Every entry is 80 except the
# Expression/General/List functors, which get 40. How these priorities are
# consumed is not visible in this module.
priority_map = dict(
    # Predicators
    AttributePredicator=80,
    AppositionalPredicator=80,
    FactualPredicator=80,
    LogicalPredicator=80,
    ReferentialPredicator=80,
    # Functors
    ConjunctionalFunctor=80,
    ExpressionFunctor=40,
    GeneralFunctor=40,
    ListFunctor=40,
    ModificationalFunctor=80,
    # Constants
    DeterminerConstant=80,
    InterjectionConstant=80,
    ModificationalConstant=80,
    NominalConstant=80,
    OtherConstant=80,
    PunctuationalConstant=80,
    SymbolConstant=80,
)


def min_node_distance(n1: GPGNode, n2: GPGNode):
    """Return the smallest word-position gap between two graph nodes.

    For two phrase nodes this is ``min |i - j|`` over every non-aux word
    position ``i`` of *n1* and ``j`` of *n2*. The result is 0 when both
    nodes contain the same position.

    Args:
        n1: first node.
        n2: second node.

    Returns:
        The minimum absolute position difference as a numpy float64.
        Returns the sentinel ``10000`` when either node is a ``GPGAuxNode``
        or when either phrase node has no non-aux words.
    """
    # Aux nodes have no surface word positions; treat them as "very far".
    if isinstance(n1, GPGAuxNode) or isinstance(n2, GPGAuxNode):
        return 10000

    assert isinstance(n1, GPGPhraseNode) and isinstance(n2, GPGPhraseNode)

    import numpy as np

    # float dtype matches the float64 result the former scipy cdist produced.
    n1_words = np.asarray(list(n1.words(with_aux=False)), dtype=float)
    n2_words = np.asarray(list(n2.words(with_aux=False)), dtype=float)

    if n1_words.size == 0 or n2_words.size == 0:
        # Nothing to measure against; np.min would raise on an empty array.
        return 10000

    # Pairwise |i - j| via broadcasting: (len1, 1) against (1, len2).
    # In one dimension this equals the Euclidean distance cdist computed.
    return np.min(np.abs(n1_words[:, np.newaxis] - n2_words[np.newaxis, :]))



def get_modification_priority(graph: GPGraph, node, modification_phrases):
    """Score each candidate modification phrase relative to a head node.

    For every phrase (an iterable of graph nodes) the score is::

        20 + bonus - penalty

    where ``bonus`` is 100 if the phrase's earliest non-aux word position
    precedes *node*'s earliest non-aux word position (else 0), and
    ``penalty`` is ``5 * (d - 1)`` when ``d > 1`` (else 0). Here ``d`` is the
    smallest ``min_node_distance`` between *node* and any member of the
    phrase. An empty phrase keeps ``d = 10000`` and gets no bonus.

    Args:
        graph: not used by this function; kept for interface compatibility.
        node: the head node. ``min`` is taken over its non-aux word
            positions, so a node with none raises ``ValueError``.
        modification_phrases: iterable of phrases; each phrase is an
            iterable of nodes.

    Returns:
        A list of scores, one per phrase, in input order.
    """
    base_priority = 20
    distance_penalty_factor = 5

    min_node_pos = min(node.words(with_aux=False))

    phrase_scores = []
    for phrase in modification_phrases:
        # 10000 acts as "no position / infinitely far" for this phrase.
        min_modification_pos = 10000
        min_dist = 10000
        for c in phrase:
            min_dist = min(min_dist, min_node_distance(node, c))
            if isinstance(c, GPGPhraseNode):
                c_words = list(c.words(with_aux=False))
                if c_words:  # min() of an empty list would raise
                    min_modification_pos = min(min_modification_pos, min(c_words))

        # Scored once per phrase, after all members are seen. Scoring inside
        # the inner loop left the score unset or stale for empty phrases.
        # If the modifier precedes the head word, give a bonus; otherwise none.
        positional_bonus = 100 if min_modification_pos < min_node_pos else 0
        # If the distance exceeds 1, deduct a fixed amount per extra token.
        distance_penalty = (min_dist - 1) * distance_penalty_factor if min_dist > 1 else 0

        phrase_scores.append(base_priority + positional_bonus - distance_penalty)

    return phrase_scores