File size: 5,499 Bytes
d2ff6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import numpy


class ScoreParam:
    """Container for the four scoring values used during alignment.

    Attributes are stored exactly as passed in: ``match``, ``mismatch``,
    ``gap_open`` and ``gap_extend``.
    """

    def __init__(self, match, mismatch, gap_open, gap_extend):
        """Store the match, mismatch, gap-open and gap-extend scores."""
        (
            self.match,
            self.mismatch,
            self.gap_open,
            self.gap_extend,
        ) = (match, mismatch, gap_open, gap_extend)

    def __str__(self):
        """Return a one-line, human-readable summary of the four scores."""
        return (
            f"Match: {self.match}, "
            f"Mismatch: {self.mismatch}, "
            f"Gap Open: {self.gap_open}, "
            f"Gap Extend: {self.gap_extend}"
        )


class SeqGraphAlignment(object):
    __default_score = ScoreParam(1, -3, -2, -1)

    def __init__(
        self,
        sequence,
        graph,
        fastMethod=True,
        globalAlign=False,
        score_params=__default_score,
        *args,
        **kwargs,
    ):
        """Align ``sequence`` against ``graph`` and keep the resulting indices.

        Args:
            sequence: The sequence to align; stored as ``self.sequence``.
            graph: The graph to align against; stored as ``self.graph``.
            fastMethod: When truthy the alignment is computed by
                ``alignStringToGraphFast``, otherwise by
                ``alignStringToGraphSimple`` (both are defined outside the
                code visible here).
            globalAlign: Stored as ``self.globalAlign``; selects global
                alignment handling in the other methods of this class.
            score_params: Scoring values, stored as ``self.score``. Defaults
                to the class-level ``ScoreParam(1, -3, -2, -1)``.
            *args: Forwarded unchanged to the chosen alignment method.
            **kwargs: Forwarded unchanged to the chosen alignment method.
        """
        self.score = score_params
        self.sequence = sequence
        self.graph = graph
        # Placeholders until the alignment below has produced its result.
        self.stringidxs = None
        self.nodeidxs = None
        self.globalAlign = globalAlign

        # Only the selected method is looked up, exactly one of the two.
        aligner = (
            self.alignStringToGraphFast
            if fastMethod
            else self.alignStringToGraphSimple
        )
        self.stringidxs, self.nodeidxs = aligner(*args, **kwargs)

    def alignmentStrings(self):
        """Render the stored alignment as a pair of strings.

        Returns:
            A 2-tuple ``(sequence_string, graph_string)``. The first is built
            from ``self.sequence`` at each index in ``self.stringidxs``; the
            second from the ``text`` of the node looked up in
            ``self.graph.nodedict`` for each ID in ``self.nodeidxs``. A
            ``None`` entry in either list is rendered as ``"-"``.
        """
        seq_parts = []
        for pos in self.stringidxs:
            seq_parts.append("-" if pos is None else self.sequence[pos])

        node_parts = []
        for node_id in self.nodeidxs:
            node_parts.append(
                "-" if node_id is None else self.graph.nodedict[node_id].text
            )

        return "".join(seq_parts), "".join(node_parts)

    def matchscore(self, c1, c2):
        """Return the match score if ``c1 == c2``, else the mismatch score.

        Both values are read from ``self.score``.
        """
        return self.score.match if c1 == c2 else self.score.mismatch

    def matchscoreVec(self, c, v):
        """Element-wise variant of ``matchscore``.

        Compares ``v`` against ``c`` with ``==`` and returns, via
        ``numpy.where``, the match score where the comparison holds and the
        mismatch score elsewhere.
        """
        is_match = v == c
        return numpy.where(is_match, self.score.match, self.score.mismatch)

    def prevIndices(self, node, nodeIDtoIndex):
        """Return the indices of ``node``'s predecessors.

        Args:
            node: Object whose ``inEdges`` mapping is keyed by predecessor ID.
            nodeIDtoIndex: Mapping from node ID to node index.

        Returns:
            A list with the index of every key in ``node.inEdges``, or
            ``[-1]`` when the node has no incoming edges.
        """
        pred_indices = [nodeIDtoIndex[pred_id] for pred_id in node.inEdges.keys()]
        # -1 is the placeholder index meaning "no predecessor".
        return pred_indices or [-1]

    def initializeDynamicProgrammingData(self):
        """Allocate and initialise the dynamic-programming tables.

        All tables are ``int32`` arrays of shape
        ``(3, nNodes + 1, len(sequence) + 1)``. The first axis selects one of
        three score layers (referred to as M, X and Y in the comments below);
        row 0 and column 0 are boundary cells.

        When ``self.globalAlign`` is false the score table is left as zeros.
        When it is true the boundaries are filled as follows, using
        ``-1000000000`` as a stand-in for minus infinity:

        * layer 0 (M): boundary cells are minus infinity, except the origin,
          which is 0;
        * layer 1 (X): row 0 holds ``gap_open + i * gap_extend`` (origin is
          minus infinity); column 0 is minus infinity for every node row;
        * layer 2 (Y): row 0 is minus infinity; column 0 of a node row is
          ``gap_open + gap_extend`` for a node without predecessors, else the
          best predecessor value plus ``gap_extend``.

        Returns:
            ``(nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx,
            backMtxIdx)`` where the two dicts map between node IDs and their
            position in iteration order (``nodeIndexToID[-1]`` is ``None``)
            and the three ``back*`` arrays are zero-filled backtracking
            tables.
        """
        n_nodes = self.graph.nNodes
        seq_len = len(self.sequence)

        # Two-way mapping between node IDs and their position in the order
        # produced by the graph's node iterator.
        nodeIDtoIndex = {}
        nodeIndexToID = {-1: None}
        node_iter = self.graph.nodeiterator()
        for position, node in enumerate(node_iter()):
            nodeIDtoIndex[node.ID] = position
            nodeIndexToID[position] = node.ID

        table_shape = (3, n_nodes + 1, seq_len + 1)
        scores = numpy.zeros(table_shape, dtype=numpy.int32)

        if self.globalAlign:
            neg_inf = -1000000000

            # M[0, i] = -inf, with the origin reset to 0
            scores[0, 0, :] = neg_inf
            scores[0, 0, 0] = 0
            # X[0, i] = gap_open + i * gap_extend, with the origin at -inf
            scores[1, 0, :] = [
                self.score.gap_open + col * self.score.gap_extend
                for col in range(seq_len + 1)
            ]
            scores[1, 0, 0] = neg_inf
            # Y[0, i] = -inf
            scores[2, 0, :] = neg_inf

            node_iter = self.graph.nodeiterator()
            # Relies on the iterator yielding nodes in topological order, so
            # every predecessor row has been filled before it is read here.
            for position, node in enumerate(node_iter()):
                row = position + 1
                scores[0, row, 0] = neg_inf
                scores[1, row, 0] = neg_inf
                pred_positions = self.prevIndices(node, nodeIDtoIndex)
                if pred_positions == [-1]:
                    # No predecessors: this node opens the gap.
                    scores[2, row, 0] = self.score.gap_open + self.score.gap_extend
                else:
                    # Extend the gap from the best-scoring predecessor.
                    best_pred = max(scores[2, p + 1, 0] for p in pred_positions)
                    scores[2, row, 0] = best_pred + self.score.gap_extend

        # 3D backtracking tables: previous string index, previous graph index
        # and previous layer for every (layer, node, position) cell.
        backStrIdx, backGrphIdx, backMtxIdx = (
            numpy.zeros(table_shape, dtype=numpy.int32) for _ in range(3)
        )

        return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx

    def backtrack(self, scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID):
        """Walk the backtracking tables and return the aligned index lists.

        Args:
            scores: Score table indexed as ``[layer, node_row, string_col]``.
            backStrIdx: For each cell, the string column of the previous cell.
            backGrphIdx: For each cell, the node row of the previous cell.
            backMtxIdx: For each cell, the layer of the previous cell.
            nodeIndexToID: Mapping from node index to node ID; it is looked
                up with ``node_row - 1`` on every step.

        Returns:
            ``(strindexes, matches)``: two lists of equal length, ordered
            from the start of the alignment to its end. For each step,
            layer 0 records both the string index and the node ID, layer 1
            records the string index with ``None`` for the node, and any
            other layer records the node ID with ``None`` for the string
            index.

        Raises:
            IndexError: In global mode, if the graph has no node with
                ``outDegree == 0``.
        """
        besti, bestj = scores.shape[1] - 1, scores.shape[2] - 1
        # For every [i, j] cell, the layer holding the highest score.
        max_m = numpy.argmax(scores, axis=0)

        if self.globalAlign:
            ni = self.graph.nodeiterator()
            # A global alignment must end on a terminal node (no outgoing
            # edges); pick the best-scoring one in the last string column.
            terminalIndices = [
                index for (index, node) in enumerate(ni()) if node.outDegree == 0
            ]
            besti = terminalIndices[0] + 1
            bestscore = scores[max_m[besti, bestj], besti, bestj]
            for i in terminalIndices[1:]:
                score = scores[max_m[i + 1, bestj], i + 1, bestj]
                if score > bestscore:
                    bestscore, besti = score, i + 1

        # Choose the starting layer in BOTH modes. Previously this was only
        # assigned for global alignment, so local alignment failed with an
        # unbound ``bestm`` on the first loop iteration.
        # NOTE(review): in local mode the walk starts from the bottom-right
        # cell, as the original start indices did; confirm this is the
        # intended local start.
        bestm = max_m[besti, bestj]

        matches = []
        strindexes = []

        while besti != 0 or bestj != 0:
            nextm = backMtxIdx[bestm, besti, bestj]
            nexti = backGrphIdx[bestm, besti, bestj]
            nextj = backStrIdx[bestm, besti, bestj]
            curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]

            # Collected end-to-start with O(1) appends; reversed once below.
            if bestm == 0:
                matches.append(curnodeidx)
                strindexes.append(curstridx)
            elif bestm == 1:
                matches.append(None)
                strindexes.append(curstridx)
            else:
                matches.append(curnodeidx)
                strindexes.append(None)

            bestm, besti, bestj = nextm, nexti, nextj

        strindexes.reverse()
        matches.reverse()
        return strindexes, matches