import numpy

# Stand-in for "minus infinity" in the int32 score tensor.
_NEG_INF = -1000000000


class ScoreParam:
    """Bundle of the four scoring values used by the alignment."""

    def __init__(self, match, mismatch, gap_open, gap_extend):
        self.match = match
        self.mismatch = mismatch
        self.gap_open = gap_open
        self.gap_extend = gap_extend

    def __str__(self):
        return f"Match: {self.match}, Mismatch: {self.mismatch}, Gap Open: {self.gap_open}, Gap Extend: {self.gap_extend}"


class SeqGraphAlignment(object):
    """Alignment of a sequence against a graph of nodes.

    Scores live in a tensor of shape ``(3, nNodes + 1, len(sequence) + 1)``.
    As used by ``backtrack``, layer 0 ("M") pairs a sequence position with a
    node, layer 1 ("X") emits a sequence position with no node, and layer 2
    ("Y") emits a node with no sequence position.
    """

    __default_score = ScoreParam(1, -3, -2, -1)

    def __init__(
        self,
        sequence,
        graph,
        fastMethod=True,
        globalAlign=False,
        score_params=__default_score,
        *args,
        **kwargs,
    ):
        """Store the inputs and run the alignment immediately.

        Args:
            sequence: Indexable sequence to align.
            graph: Graph object providing ``nNodes``, ``nodeiterator()`` and
                ``nodedict``.
            fastMethod: Selects ``alignStringToGraphFast`` when true, otherwise
                ``alignStringToGraphSimple``.
            globalAlign: Whether to initialise and backtrack as a global
                alignment.
            score_params: ``ScoreParam`` with the scoring values.
            *args: Forwarded to the selected alignment method.
            **kwargs: Forwarded to the selected alignment method.
        """
        self.score = score_params
        self.sequence = sequence
        self.graph = graph
        self.stringidxs = None
        self.nodeidxs = None
        self.globalAlign = globalAlign

        # Both alignment methods are defined outside this section of the file.
        if fastMethod:
            matches = self.alignStringToGraphFast(*args, **kwargs)
        else:
            matches = self.alignStringToGraphSimple(*args, **kwargs)
        self.stringidxs, self.nodeidxs = matches

    def alignmentStrings(self):
        """Return the aligned sequence and node texts, using "-" for gaps."""
        return (
            "".join(self.sequence[i] if i is not None else "-" for i in self.stringidxs),
            "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs),
        )

    def matchscore(self, c1, c2):
        """Return the match score if ``c1 == c2``, else the mismatch score."""
        if c1 == c2:
            return self.score.match
        return self.score.mismatch

    def matchscoreVec(self, c, v):
        """Element-wise ``matchscore`` of ``c`` against every entry of ``v``."""
        return numpy.where(v == c, self.score.match, self.score.mismatch)

    def prevIndices(self, node, nodeIDtoIndex):
        """Return the indices of ``node``'s predecessors, or ``[-1]`` if none."""
        prev = [nodeIDtoIndex[predID] for predID in node.inEdges]
        return prev if prev else [-1]

    def initializeDynamicProgrammingData(self):
        """Allocate the score and backtracking tensors and the node ID maps.

        Returns:
            Tuple ``(nodeIDtoIndex, nodeIndexToID, scores, backStrIdx,
            backGrphIdx, backMtxIdx)``. ``nodeIndexToID`` also maps ``-1`` to
            ``None``. All four arrays are int32 with shape
            ``(3, nNodes + 1, len(sequence) + 1)``; the backtracking arrays
            start out all zero.
        """
        l1 = self.graph.nNodes
        l2 = len(self.sequence)

        nodeIDtoIndex = {}
        nodeIndexToID = {-1: None}
        ni = self.graph.nodeiterator()
        for index, node in enumerate(ni()):
            nodeIDtoIndex[node.ID] = index
            nodeIndexToID[index] = node.ID

        shape = (3, l1 + 1, l2 + 1)
        scores = numpy.zeros(shape, dtype=numpy.int32)

        if self.globalAlign:
            gap_open = self.score.gap_open
            gap_extend = self.score.gap_extend

            # M[0, j] = -inf, except the origin which costs nothing.
            scores[0, 0, :] = _NEG_INF
            scores[0, 0, 0] = 0
            # X[0, j] = gap_open + j * gap_extend, with the origin excluded.
            scores[1, 0, :] = gap_open + numpy.arange(l2 + 1) * gap_extend
            scores[1, 0, 0] = _NEG_INF
            # Y[0, j] = -inf
            scores[2, 0, :] = _NEG_INF

            # NOTE(review): this boundary-column pass is kept inside the
            # global-alignment branch because it only writes -inf / gap
            # penalties; confirm against the intended layout.
            # After topology sort, the predecessors will have an index less
            # than the current node, so their column-0 entries are already set.
            ni = self.graph.nodeiterator()
            for index, node in enumerate(ni()):
                row = index + 1
                scores[0, row, 0] = _NEG_INF
                scores[1, row, 0] = _NEG_INF
                prevIdxs = self.prevIndices(node, nodeIDtoIndex)
                if prevIdxs == [-1]:
                    # No predecessors: this node opens the gap.
                    scores[2, row, 0] = gap_open + gap_extend
                else:
                    best = max(scores[2, prevIdx + 1, 0] for prevIdx in prevIdxs)
                    scores[2, row, 0] = best + gap_extend

        # 3D backtracking pointers: string index, graph index, matrix layer.
        backStrIdx = numpy.zeros(shape, dtype=numpy.int32)
        backGrphIdx = numpy.zeros(shape, dtype=numpy.int32)
        backMtxIdx = numpy.zeros(shape, dtype=numpy.int32)

        return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx

    def backtrack(self, scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID):
        """Follow the back-pointers from the end cell to the origin.

        Args:
            scores: Score tensor of shape ``(3, rows, cols)``.
            backStrIdx: Next column for every ``(layer, row, col)`` cell.
            backGrphIdx: Next row for every ``(layer, row, col)`` cell.
            backMtxIdx: Next layer for every ``(layer, row, col)`` cell.
            nodeIndexToID: Mapping from node index to node ID.

        Returns:
            Tuple ``(strindexes, matches)`` of equal-length lists ordered from
            the start of the alignment. ``strindexes`` holds sequence indices
            and ``matches`` holds node IDs; ``None`` marks a gap on that side.

        Raises:
            IndexError: If ``globalAlign`` is set and no node of the graph has
                ``outDegree == 0``.
        """
        besti, bestj = scores.shape[1] - 1, scores.shape[2] - 1

        # Best-scoring layer for every (row, col) cell.
        max_m = numpy.argmax(scores, axis=0)

        if self.globalAlign:
            # A global alignment must end on a terminal node; pick the one
            # whose last-column cell scores highest (first one wins ties).
            ni = self.graph.nodeiterator()
            terminalRows = [
                index + 1 for index, node in enumerate(ni()) if node.outDegree == 0
            ]
            if not terminalRows:
                raise IndexError(
                    "graph has no terminal node (outDegree == 0) to backtrack from"
                )
            besti = max(
                terminalRows,
                key=lambda row: scores[max_m[row, bestj], row, bestj],
            )

        bestm = max_m[besti, bestj]

        # Collected end-to-start and reversed once at the end.
        matches = []
        strindexes = []
        while besti != 0 or bestj != 0:
            nextm = backMtxIdx[bestm, besti, bestj]
            nexti = backGrphIdx[bestm, besti, bestj]
            nextj = backStrIdx[bestm, besti, bestj]
            curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]

            if bestm == 0:
                # Sequence position paired with a node.
                matches.append(curnodeidx)
                strindexes.append(curstridx)
            elif bestm == 1:
                # Sequence position with no node.
                matches.append(None)
                strindexes.append(curstridx)
            else:
                # Node with no sequence position.
                matches.append(curnodeidx)
                strindexes.append(None)

            bestm, besti, bestj = nextm, nexti, nextj

        matches.reverse()
        strindexes.reverse()
        return strindexes, matches