Spaces:
Build error
Build error
Upload 11 files
Browse files- src/__init__.py +0 -0
- src/alignment.py +256 -0
- src/generation_methods.py +299 -0
- src/generation_utils.py +190 -0
- src/global_edit_utils.py +127 -0
- src/new_alignment.py +150 -0
- src/new_text_alignment.py +134 -0
- src/poa_graph.py +685 -0
- src/text_poa_graph.py +802 -0
- src/text_poa_graph_utils.py +126 -0
- src/utils.py +46 -0
src/__init__.py
ADDED
|
File without changes
|
src/alignment.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from Jonathan Dursi
|
| 3 |
+
https://github.com/ljdursi/poapy
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SeqGraphAlignment(object):
    """Smith-Waterman-style alignment of a string against a POA graph.

    Generalizes pairwise sequence alignment to a DAG: each graph node may
    have several predecessors, so every dynamic-programming cell considers
    all of them and keeps the best-scoring move. With ``globalAlign=False``
    the alignment is local (scores clamped at zero); with
    ``globalAlign=True`` end gaps are penalized and the traceback starts
    from the best-scoring terminal node of the graph.
    """

    # Default scoring parameters; overridable per instance via __init__.
    __matchscore = 1
    __mismatchscore = -2
    __gap = -1

    def __init__(
        self,
        sequence,
        graph,
        fastMethod=True,
        globalAlign=False,
        matchscore=__matchscore,
        mismatchscore=__mismatchscore,
        gapscore=__gap,
        *args,
        **kwargs,
    ):
        """Align ``sequence`` (a str) to ``graph`` immediately on construction.

        The alignment is stored in ``self.stringidxs`` / ``self.nodeidxs``:
        parallel lists of string indices and node IDs, with ``None`` marking
        a gap on that side.
        """
        self._mismatchscore = mismatchscore
        self._matchscore = matchscore
        self._gap = gapscore
        self.sequence = sequence
        self.graph = graph
        self.stringidxs = None
        self.nodeidxs = None
        self.globalAlign = globalAlign
        # Both methods compute the same alignment; the "fast" variant
        # vectorizes across the string with numpy.
        if fastMethod:
            matches = self.alignStringToGraphFast(*args, **kwargs)
        else:
            matches = self.alignStringToGraphSimple(*args, **kwargs)
        self.stringidxs, self.nodeidxs = matches

    def alignmentStrings(self):
        """Return (sequence_side, graph_side) strings with '-' at gaps."""
        return "".join(
            self.sequence[i] if i is not None else "-" for i in self.stringidxs
        ), "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs)

    def matchscore(self, c1, c2):
        """Score a single base pair: match or mismatch."""
        if c1 == c2:
            return self._matchscore
        else:
            return self._mismatchscore

    def matchscoreVec(self, c, v):
        """Vectorized matchscore: score base ``c`` against numpy array ``v``."""
        return numpy.where(v == c, self._matchscore, self._mismatchscore)

    def alignStringToGraphSimple(self):
        """Align string to graph, following same approach as smith waterman
        example"""
        if not isinstance(self.sequence, str):
            raise TypeError("Invalid Type")

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
            self.initializeDynamicProgrammingData()
        )

        # Dynamic Programming
        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            pbase = node.text

            for j, sbase in enumerate(self.sequence):
                # add all candidates to a list, pick the best
                candidates = [(scores[i + 1, j] + self._gap, i + 1, j, "INS")]
                for predIndex in self.prevIndices(node, nodeIDtoIndex):
                    candidates += [
                        (scores[predIndex + 1, j + 1] + self._gap, predIndex + 1, j + 1, "DEL")
                    ]
                    candidates += [
                        (
                            scores[predIndex + 1, j] + self.matchscore(sbase, pbase),
                            predIndex + 1,
                            j,
                            "MATCH",
                        )
                    ]

                (
                    scores[i + 1, j + 1],
                    backGrphIdx[i + 1, j + 1],
                    backStrIdx[i + 1, j + 1],
                    movetype,
                ) = max(candidates)

                # Local alignment: never drop below zero, and clear the
                # backtrack pointers so traceback stops here.
                if not self.globalAlign and scores[i + 1, j + 1] < 0:
                    scores[i + 1, j + 1] = 0.0
                    backGrphIdx[i + 1, j + 1] = -1
                    backStrIdx[i + 1, j + 1] = -1

        return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)

    def alignStringToGraphFast(self):
        """Align string to graph - using numpy to vectorize across the string
        at each iteration."""
        if not isinstance(self.sequence, str):
            raise TypeError("Invalid Type")

        l2 = len(self.sequence)
        seqvec = numpy.array(list(self.sequence))

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
            self.initializeDynamicProgrammingData()
        )
        inserted = numpy.zeros((l2), dtype=bool)

        # having the inner loop as a function improves performance
        # can use Cython, etc on this for significant further improvements
        # can't vectorize this since there's a loop-carried dependency
        # along the string
        def insertions(i, l2, scores, inserted):
            inserted[:] = False
            for j in range(l2):
                insscore = scores[i + 1, j] + self._gap
                if insscore >= scores[i + 1, j + 1]:
                    scores[i + 1, j + 1] = insscore
                    inserted[j] = True

        # Dynamic Programming
        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            gbase = node.text
            predecessors = self.prevIndices(node, nodeIDtoIndex)

            # calculate all best deletions, matches in one go over all
            # predecessors.

            # First calculate for the first predecessor, over all string posns:
            deletescore = scores[predecessors[0] + 1, 1:] + self._gap
            bestdelete = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1

            matchpoints = self.matchscoreVec(gbase, seqvec)
            matchscore = scores[predecessors[0] + 1, 0:-1] + matchpoints
            bestmatch = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1

            # then, the remaining
            for predecessor in predecessors[1:]:
                newdeletescore = scores[predecessor + 1, 1:] + self._gap
                bestdelete = numpy.where(newdeletescore > deletescore, predecessor + 1, bestdelete)
                deletescore = numpy.maximum(newdeletescore, deletescore)

                # NOTE(review): re-scoring here uses the *predecessor's* base
                # rather than the current node's — confirm against upstream
                # poapy before changing; kept as-is to preserve behavior.
                gbase = self.graph.nodeIdxToBase(predecessor)
                matchpoints = self.matchscoreVec(gbase, seqvec)
                newmatchscore = scores[predecessor + 1, 0:-1] + matchpoints
                bestmatch = numpy.where(newmatchscore > matchscore, predecessor + 1, bestmatch)
                matchscore = numpy.maximum(newmatchscore, matchscore)

            # choose best options available of match, delete
            deleted = deletescore >= matchscore
            backGrphIdx[i + 1, 1:] = numpy.where(deleted, bestdelete, bestmatch)
            backStrIdx[i + 1, 1:] = numpy.where(
                deleted, numpy.arange(1, l2 + 1), numpy.arange(0, l2)
            )
            scores[i + 1, 1:] = numpy.where(deleted, deletescore, matchscore)

            # insertions: updated in place, don't depend on predecessors
            insertions(i, l2, scores, inserted)
            backGrphIdx[i + 1, 1:] = numpy.where(inserted, i + 1, backGrphIdx[i + 1, 1:])
            backStrIdx[i + 1, 1:] = numpy.where(inserted, numpy.arange(l2), backStrIdx[i + 1, 1:])

            # if we're doing local alignment, don't let bad global alignment
            # drag us negative
            if not self.globalAlign:
                backGrphIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backGrphIdx[i + 1, :], -1)
                backStrIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backStrIdx[i + 1, :], -1)
                scores[i + 1, :] = numpy.maximum(scores[i + 1, :], 0)

        return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)

    def prevIndices(self, node, nodeIDtoIndex):
        """Return a list of the previous dynamic programming table indices
        corresponding to predecessors of the current node."""
        prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
        # if no predecessors, point to just before the graph
        if not prev:
            prev = [-1]
        return prev

    def initializeDynamicProgrammingData(self):
        """Initalize the dynamic programming tables:
        - set up scores array
        - set up backtracking array
        - create index to Node ID table and vice versa"""
        l1 = self.graph.nNodes
        l2 = len(self.sequence)

        nodeIDtoIndex = {}
        nodeIndexToID = {-1: None}
        # generate a dict of (nodeID) -> (index into nodelist (and thus matrix))
        ni = self.graph.nodeiterator()
        for index, node in enumerate(ni()):
            nodeIDtoIndex[node.ID] = index
            nodeIndexToID[index] = node.ID

        # Dynamic Programming data structures; scores matrix and backtracking
        # matrix
        scores = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)

        # initialize insertion score
        # if global align, penalty for starting at head != 0
        if self.globalAlign:
            scores[0, :] = numpy.arange(l2 + 1) * self._gap

            ni = self.graph.nodeiterator()
            for index, node in enumerate(ni()):
                prevIdxs = self.prevIndices(node, nodeIDtoIndex)
                best = scores[prevIdxs[0] + 1, 0]
                for prevIdx in prevIdxs:
                    best = max(best, scores[prevIdx + 1, 0])
                scores[index + 1, 0] = best + self._gap

        # backtracking matrices
        backStrIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
        backGrphIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)

        return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx

    def backtrack(self, scores, backStrIdx, backGrphIdx, nodeIndexToID):
        """Backtrack through the scores and backtrack arrays.
        Return a list of sequence indices and node IDs (not indices, which
        depend on ordering)."""
        besti, bestj = scores.shape
        besti -= 1
        bestj -= 1
        if not self.globalAlign:
            # Local alignment: start from the highest-scoring cell.
            besti, bestj = numpy.argwhere(scores == numpy.amax(scores))[-1]
        else:
            ni = self.graph.nodeiterator()
            # still have to find best final index to start from
            terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
            besti = terminalIndices[0] + 1
            bestscore = scores[besti, bestj]
            for i in terminalIndices[1:]:
                score = scores[i + 1, bestj]
                if score > bestscore:
                    bestscore, besti = score, i + 1

        matches = []
        strindexes = []
        # Walk pointers back until the origin (global) or a zero cell (local).
        while (self.globalAlign or scores[besti, bestj] > 0) and (besti != 0 or bestj != 0):
            nexti, nextj = backGrphIdx[besti, bestj], backStrIdx[besti, bestj]
            curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]

            # None marks a gap on the side whose index did not advance.
            strindexes.insert(0, curstridx if nextj != bestj else None)
            matches.insert(0, curnodeidx if nexti != besti else None)

            besti, bestj = nexti, nextj

        return strindexes, matches
|
src/generation_methods.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from src.generation_utils import (
|
| 4 |
+
extract_alternative_paths,
|
| 5 |
+
extract_context,
|
| 6 |
+
extract_equivalent_classes,
|
| 7 |
+
self_complete,
|
| 8 |
+
verify_correctness_pairwise,
|
| 9 |
+
)
|
| 10 |
+
from src.global_edit_utils import clean_up_text
|
| 11 |
+
from src.text_poa_graph import TextPOAGraph
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
Decodes from a TextPOAGraph object to a string by sequentially selecting nodes based on the selection threshold.
|
| 15 |
+
Only the primary variation of selected variable nodes are selected.
|
| 16 |
+
Text is edited using the global_edit_function (e.g. to clean up text by removing incoherencies, disfluencies, and redundancies).
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
text_poa_graph: The TextPOAGraph object to decode.
|
| 20 |
+
selection_threshold: The threshold for selecting nodes.
|
| 21 |
+
model: The model to use for decoding.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
A string of the decoded text.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: float = 0.5,
    task: str = "bio",
    verbose: bool = False,
    **kwargs,
) -> str:
    """Decode the graph's consensus path into a cleaned-up text.

    Walks the consensus nodes (after a topological sort), additionally
    keeping any non-consensus successor supported by at least
    ``selection_threshold`` of the input sequences. For nodes with
    variations, the longest variant is kept. The joined text is then edited
    by the LLM-based ``clean_up_text`` helper.

    Args:
        text_poa_graph: graph to decode; "Abstain" is returned if it failed
            to build.
        selection_threshold: minimum fraction of sequences that must support
            a non-consensus neighbor for it to be included. (Was annotated
            ``Optional[float]`` but ``None`` is never handled.)
        task: editing task hint forwarded to ``clean_up_text``.
        verbose: if True, return ``(raw_text, edited_text)`` instead of just
            the edited text.
        **kwargs: forwarded to ``clean_up_text``.

    Returns:
        The edited consensus text, or a (raw, edited) tuple when verbose.
    """
    if text_poa_graph.failed:
        return "Abstain"

    # Ensure node order respects edge direction before walking.
    text_poa_graph.toposort()

    consensus_node_ids = text_poa_graph.consensus_node_ids

    selected_node_ids = []
    for node_id in consensus_node_ids:
        # start/end sentinels carry no text
        if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
            continue

        selected_node_ids.append(node_id)

        # also pull in well-supported non-consensus successors
        for neighbor_id in text_poa_graph.nodedict[node_id].outEdges:
            if neighbor_id in consensus_node_ids:
                continue

            if (
                len(text_poa_graph.nodedict[neighbor_id].labels) / text_poa_graph.num_sequences
                >= selection_threshold
            ):
                selected_node_ids.append(neighbor_id)

    texts = []
    for node_id in selected_node_ids:
        node = text_poa_graph.nodedict[node_id]
        if not node.variations:
            texts.append(node.text)
        else:
            # select the longest among the base text and all variations
            all_texts = list(node.variations.values())
            all_texts.append(node.text)
            texts.append(max(all_texts, key=len))

    text = " ".join(texts)
    edited_text = clean_up_text(text=text, task=task, api="openai", **kwargs)
    if verbose:
        return text, edited_text
    return edited_text
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Decode a solution from the POA graph with LLM-based self-verification.

    Pipeline:
      1. Find consensus nodes whose branching factor (out-degree divided by
         the number of sequences) exceeds ``uncertainty_threshold``.
      2. Render every sequence to text, wrapping the region after each
         high-uncertainty node in ``*START_SEPARATOR*``/``*END_SEPARATOR*``
         markers.
      3. For each uncertain node, pairwise-verify representative branches
         with ``verify_correctness_pairwise``; prune branches judged
         incorrect (optionally giving each one "grace" miss first) and mark
         their regions as possible errors.
      4. Ask the model (via ``self_complete``) to synthesize a final
         solution from the surviving (or, if none survive, all) approaches.

    Returns:
        (completion, masked_candidates): the model's final completion and
        the per-label marker-annotated candidate texts.
    """
    # --- 1. locate high-branching consensus nodes -------------------------
    high_uncertainty_nodes = []
    for node_id in text_poa_graph.consensus_node_ids:
        # start/end sentinels are structural and never "uncertain"
        if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
            continue

        outgoing_edges = text_poa_graph.nodedict[node_id].outEdges
        branching_factor = len(outgoing_edges) / text_poa_graph.num_sequences

        if branching_factor > uncertainty_threshold:
            high_uncertainty_nodes.append(node_id)

    # --- 2. render each sequence with uncertainty markers -----------------
    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = {}
    uncertain_region = False
    for label in selected_labels:
        text = ""
        for node_id in text_poa_graph._seq_paths[label]:
            # NOTE(review): the START marker is emitted on the node *after*
            # the high-uncertainty node (the flag is only set below), so each
            # marked region covers the successor node(s) — confirm intended.
            if uncertain_region:
                text += f" *START_SEPARATOR*_{node_id} "
            if node_id in high_uncertainty_nodes:
                uncertain_region = True

            # Prefer this sequence's variation of the node text.
            # NOTE(review): assumes `label` is a key whenever variations is
            # non-empty; a missing label would raise KeyError — verify.
            if len(text_poa_graph.nodedict[node_id].variations) > 0:
                text += text_poa_graph.nodedict[node_id].variations[label]
                text += " "
            else:
                text += text_poa_graph.nodedict[node_id].text
                text += " "

            # Close the marked region at the first non-uncertain node.
            if uncertain_region and node_id not in high_uncertainty_nodes:
                text += f" *END_SEPARATOR*_{node_id} "
                uncertain_region = False
        masked_candidates[label] = text

    patch_start_node = None  # set if every branch gets pruned at some node
    uncertain_ids = []

    # give a grace period for the first incorrect step
    prev_step = {label: None for label in selected_labels}

    # --- 3. pairwise verification and pruning at each uncertain node ------
    for node_id in high_uncertainty_nodes:
        uncertain_ids.append(node_id)
        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
        new_labels = selected_labels.copy()

        # Only do self-verifaction for labels from different sematically equivalent branches
        if len(equivalent_classes) <= 1:
            continue
        i = 0
        # Verify classes two at a time, using the first label of each class
        # as that class's representative.
        while i < len(equivalent_classes):
            if i + 1 < len(equivalent_classes):
                label_a = equivalent_classes[i][0]
                label_b = equivalent_classes[i + 1][0]
                full_a = context_before[label_a] + alternative_paths[label_a]
                full_b = context_before[label_b] + alternative_paths[label_b]

                # score is a pair of per-solution correctness scores
                score = verify_correctness_pairwise(
                    full_text_1=full_a,
                    full_text_2=full_b,
                    verification_model=verification_model,
                    problem=problem,
                    api=verification_api,
                )
                if float(score[0]) < 1.0:
                    print(f"Label {label_a} is incorrect at node {node_id}")
                    # Re-mark this node's uncertain region as a likely error.
                    masked_candidates[label_a] = (
                        masked_candidates[label_a]
                        .replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
                        .replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
                    )
                    if not prev_step[label_a]:
                        prev_step[label_a] = True
                    # NOTE(review): parses as (prev_step and grace_period) or
                    # (not grace_period); since prev_step was just set True
                    # above, the branch prunes on the *first* miss even with
                    # grace_period=True — confirm this is the intended grace
                    # semantics.
                    if prev_step[label_a] and grace_period or not grace_period:
                        for label_i in equivalent_classes[i]:
                            new_labels.remove(label_i)
                            print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
                if float(score[0]) == 1.0:
                    prev_step[label_a] = False
                if float(score[1]) < 1.0:
                    print(f"Label {label_b} is incorrect at node {node_id}")
                    masked_candidates[label_b] = (
                        masked_candidates[label_b]
                        .replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
                        .replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
                    )
                    if not prev_step[label_b]:
                        prev_step[label_b] = True
                    if prev_step[label_b] and grace_period or not grace_period:
                        for label_i in equivalent_classes[i + 1]:
                            new_labels.remove(label_i)
                            print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
                if float(score[1]) == 1.0:
                    prev_step[label_b] = False
                i += 2
            else:
                # Odd class out: nothing to compare it against this round.
                break

        if len(new_labels) == 0:
            # Every branch was pruned: fall back to patch mode below.
            patch_start_node = node_id
            break

        selected_labels = new_labels.copy()

    # These are the pruned approaches with masking
    # NOTE(review): debug print left in; consider removing or using logging.
    print(masked_candidates)
    masked_approaches = "\n".join(
        [
            f"Approach {label}: {masked_candidates[label].replace('START_SEPARATOR', 'START_UNCERTAIN_REGION').replace('END_SEPARATOR', 'END_UNCERTAIN_REGION')}"
            for label in selected_labels
        ]
    )
    # These are all approaches with masking
    all_approaches = "\n".join(
        [f"Approach {label}: {masked_candidates[label]}" for label in masked_candidates.keys()]
    )

    # --- 4. build the synthesis prompt ------------------------------------
    default_prompt = f"""
    Solve the following math problem with mathematical precision and clarity.

    Problem: {problem}

    Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*).
    These sections may contain conceptual or computational errors.

    There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
    A verification step indicated that these steps are highly likely to contain errors.

    Potential Approaches:
    {masked_approaches}

    Your task:
    1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
    If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
    2. Using the sections with special markers, identify potential errors.
    3. Develop a rigorous, step-by-step solution based on sound mathematical principles
    4. For uncertain regions:
    - Verify each step using algebraic or numerical validation
    - If correct, incorporate these steps with appropriate justification
    - If incorrect, provide clear corrections with mathematical reasoning for your changes
    5. Follow a comparative approach, using the differences between approaches to identify potential errors.
    6. Do not blindly follow the approaches, but rather use them to identify potential errors.

    Guidelines for your solution:
    - Begin with a strategic overview of your chosen approach
    - Present each mathematical step with clear notation and justification
    - Pay special attention to areas that were previously marked uncertain

    Conclude your solution with:
    Therefore, the final answer is: $\\boxed{{answer}}$.

    Solution:
    """

    patch_prompt = f"""
    Solve the following mathematical problem with precision and clarity.

    Problem: {problem}

    You have been provided with several partial solution approaches that attempted to solve this problem.
    None of these approaches are correct, but may contain valuable insights.
    Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
    A verification step indicated that these steps are likely to contain errors.

    INSTRUCTIONS:
    1. Synthesize a correct solution using insights from the previous approaches
    2. Pay special attention to fixing the problematic areas marked by separators
    3. Develop your solution step-by-step, showing clear mathematical reasoning
    4. Focus especially on mathematical correctness in areas where previous solutions diverged
    5. Present your work in a logical, sequential manner suitable for an advanced reader

    GUIDELINES FOR MATHEMATICAL RIGOR:
    1. MAINTAIN MATHEMATICAL RIGOR
    - Verify that all mathematical operations follow from established principles and definitions
    - Ensure dimensional consistency throughout calculations
    - Check that algebraic manipulations preserve equality and do not introduce errors

    2. CONSIDER ALTERNATIVE PERSPECTIVES
    - Even when approaches reach the same conclusion, examine their reasoning independently
    - Look for more elegant or insightful connections that may be missed across all approaches
    - Consider whether fundamental mathematical principles suggest a different path

    3. CRITICAL VALIDATION
    - Test conclusions using known mathematical properties and relationships
    - When possible, verify results using alternative methods
    - Be especially cautious when all approaches agree on a result but use similar reasoning

    4. USE PRECISION IN CORRECTIONS
    - When correcting uncertain regions, specify exactly what was incorrect and why
    - Provide clear mathematical justification for any changes
    - Ensure corrections align with standard mathematical principles and notations

    Previous Approaches (for reference only):
    {all_approaches}

    Your Solution:
    [Begin with a clear statement of your approach]
    [Provide detailed mathematical steps]
    [Ensure correct handling of complex mathematical operations]
    [Verify your work at key points, especially in previously problematic areas]

    Always conclude with:
    Therefore, the final answer is: $\\boxed{{answer}}$
    """

    # Patch mode when everything was pruned or only one candidate exists.
    if patch_start_node is not None or len(masked_candidates.keys()) == 1:
        print("None correct, patching")
        prompt = patch_prompt
    else:
        prompt = default_prompt

    return self_complete(
        verification_prompt=prompt, verification_model=verification_model, api=verification_api
    ), masked_candidates
|
src/generation_utils.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from together import Together
|
| 6 |
+
|
| 7 |
+
from src.text_poa_graph import TextPOAGraph
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def extract_context(text_poa_graph, node_id):
    """Return, per sequence label, the rendered text of that sequence's path
    up to and including `node_id`.

    Each node contributes its label-specific variation when one exists,
    otherwise its base text; pieces are joined with single spaces. Raises
    ValueError if `node_id` is absent from any sequence path.
    """
    nodes = text_poa_graph.nodedict

    def render(nid, label):
        # Prefer this sequence's variation of the node; fall back to base text.
        node = nodes[nid]
        return node.variations.get(label, node.text)

    result = {}
    for label, path in text_poa_graph._seq_paths.items():
        cutoff = path.index(node_id) + 1
        pieces = [render(nid, label) for nid in path[:cutoff]]
        result[label] = " ".join(pieces)
    return result
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def extract_alternative_paths(text_poa_graph: "TextPOAGraph", node_id):
    """Extract all alternative paths from this uncertainty point to the next consensus node.

    For every sequence label, returns the text of the path segment that runs
    from just after `node_id` up to and including the next consensus node on
    that path (empty string if no consensus node follows). Node text uses
    the label-specific variation when present, else the base text.

    Raises ValueError if `node_id` is absent from any sequence path.
    """
    alternative_paths = {}
    for label, path in text_poa_graph._seq_paths.items():
        idx = path.index(node_id)
        next_cn_idx = None
        for i in range(idx + 1, len(path)):
            if path[i] in text_poa_graph.consensus_node_ids:
                next_cn_idx = i
                break

        # Compare against None explicitly: a node id of 0 is falsy but valid,
        # so the previous `if next_cn:` check silently dropped a consensus
        # node whose id happened to be 0. Reusing the found index also avoids
        # a redundant second path.index() scan.
        if next_cn_idx is not None:
            alternative_segment = path[idx + 1 : next_cn_idx + 1]
        else:
            alternative_segment = []

        alternative_paths[label] = " ".join(
            text_poa_graph.nodedict[nid].variations.get(label, text_poa_graph.nodedict[nid].text)
            for nid in alternative_segment
        )
    return alternative_paths
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def is_same_branch(text_poa_graph: "TextPOAGraph", node_id, lable_1, label_2):
    """Check whether two sequences continue to the same node after `node_id`.

    Returns True when both sequences' paths have the same successor node
    immediately after `node_id`.

    NOTE(review): the first label parameter is spelled ``lable_1`` in the
    original API; it is kept as-is so keyword callers are not broken.

    Raises:
        ValueError: if `node_id` is absent from either path.
        IndexError: if `node_id` is the last node on either path.
    """
    path_1 = text_poa_graph._seq_paths[lable_1]
    path_2 = text_poa_graph._seq_paths[label_2]
    idx_1 = path_1.index(node_id)
    idx_2 = path_2.index(node_id)
    return path_1[idx_1 + 1] == path_2[idx_2 + 1]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def extract_equivalent_classes(text_poa_graph: TextPOAGraph, node_id, selected_labels):
    """Partition ``selected_labels`` into groups that take the same branch after node_id.

    Greedy single-pass grouping: each label is compared against the first
    member (the representative) of every existing group and joins the first
    group whose representative shares its successor node; otherwise it starts
    a new group.  Returns [] for an empty/falsy ``selected_labels``.
    """
    if not selected_labels:
        return []

    equivalent_classes = []
    for label in selected_labels:
        for group in equivalent_classes:
            if is_same_branch(text_poa_graph, node_id, group[0], label):
                group.append(label)
                break
        else:
            # No existing group matched: this label founds a new class.
            equivalent_classes.append([label])
    return equivalent_classes
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def verify_correctness_pairwise(
    full_text_1: str, full_text_2: str, verification_model: str, problem: str, api: str = "openai"
):
    """Pairwise verification of two partial solution paths.

    Sends the problem statement plus both partial solutions to a judge model
    and extracts a validity score tuple from its reply.

    Args:
        full_text_1: First partial solution text.
        full_text_2: Second partial solution text.
        verification_model: Model name passed to the chat-completions API.
        problem: The problem statement both solutions address.
        api: Backend to use: "openai", "hf", or "together".

    Returns:
        A pair (score_1, score_2).  NOTE(review): when the regex matches, the
        values are *strings* (e.g. ("1", "0") or ("1.0", "0.0")); when no
        score is found, the fallback is the *int* tuple (0, 0).  Callers must
        tolerate both forms.

    Raises:
        ValueError: If ``api`` is not one of the supported backends.
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    elif api == "together":
        client = Together()
    else:
        raise ValueError(f"Invalid API: {api}")

    # Judge prompt: comparison-driven verification with explicit guardrails
    # against penalizing incompleteness or mere differences between solutions.
    prompt = f"""
You will be given a problem and 2 partial solutions.
Your task is to use comparison as an EFFICIENCY TOOL to quickly identify potential errors.
You will be given guidelines to follow, and you will be penalized if you do not follow them.

Problem: {problem}

Partial Solution 1: {full_text_1}
Partial Solution 2: {full_text_2}

CRITICAL GUIDELINES:
- DO NOT penalize a solution for being incomplete or having missing steps
- DO NOT make a comparison of which solution is better
- DO NOT consider steps incorrect just because they differ between solutions
- DO NOT prematurely evaluate based on final answers or future steps
- DO NOT expect both solutions to be at the same stage of completion
- DO NOT consider a step incorrect just because it lacks sufficient detail or justification

KEY EFFICIENCY PRINCIPLE:
- Use agreement between solutions as evidence of correctness
- Use disagreement as a signal to investigate more deeply
- Only label a step as an error if it contains a specific mathematical mistake
- Incompleteness is not a mathematical error.

Here are the instructions for how to complete your task:

EFFICIENT VERIFICATION APPROACH:

1. QUICK COMPARISON (Use this to focus your attention):
- Immediately identify where the solutions differ in approach or results
- Use these differences as "error hotspots" to prioritize your verification
- When solutions agree, you can generally assume that part is correct
- When solutions disagree, investigate those specific points deeply

2. TARGETED VERIFICATION (Only where needed):
- Most important: Do not consider any incomplete steps as errors
- Focus your mathematical verification on the "hotspots" identified above
- Check mathematical validity only at points of difference or uncertainty
- Avoid line-by-line checking of steps where solutions agree
- For each potential error spot, verify if the mathematical reasoning is valid
- If an intermediate step is later corrected, do not penalize the solution for having the incorrect intermediate step

After your targeted verification, propose a score tuple (score_1, score_2):
- Score (1,1) if both partial solutions are valid
- Score (1,0) if only the first solution is valid
- Score (0,1) if only the second solution is valid
- Score (0,0) if both solutions are invalid

In case you score a solution as 0, you must give an explanation for each check below:
3. FINAL CHECKS:
- If you score a solution as 0, you MUST identify the specific mathematical error.
- You must also double check the problem statement. Reconsider your score and determine if you have misinterpreted the problem statement.
- You must also check whether you have penalized a solution for being incomplete or having missing steps.

Before outputting your final score, you must answer these questions:
STOP! Did you give a score of 0 to a solution that was incomplete?
STOP! Did you penalize a solution for being incomplete or having missing steps?
STOP! Did you make a comparison of which solution is better?
STOP! Did you consider steps incorrect just because they differ between solutions?
STOP! Did you prematurely evaluate based on final answers?
STOP! Did you consider a step incorrect just because it lacks sufficient detail or justification?

Now give your final score:
Final score:
"""
    # temperature=0.0 for (near-)deterministic judging.
    completion = client.chat.completions.create(
        model=verification_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
    )
    response = completion.choices[0].message.content.strip()
    # Debug/trace output preserved from the original implementation.
    print(full_text_1)
    print(full_text_2)
    print(f"Correctness score: {response} \n")
    # Take the LAST "(a, b)" tuple in the reply, where a and b are 0/1,
    # optionally written as "0.0"/"1.0".
    score_match = re.findall(r"\(\s*([01](?:\.0)?)\s*,\s*([01](?:\.0)?)\s*\)", response)
    score = score_match[-1] if score_match else (0, 0)
    return score
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def self_complete(verification_prompt: str, verification_model: str, api: str = "openai"):
    """Completion method: send ``verification_prompt`` to the chosen backend.

    Fix: the original placed this docstring after the first statement, where
    it was a dead string expression rather than a real docstring.

    Args:
        verification_prompt: Prompt to send to the model.
        verification_model: Model name passed to the chat-completions API.
        api: Backend to use: "openai", "hf", or "together".

    Returns:
        The model's reply text, stripped of surrounding whitespace.

    Raises:
        ValueError: If ``api`` is not one of the supported backends.
    """
    # Echo the prompt for debugging/trace purposes (preserved behavior).
    print(verification_prompt)
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    elif api == "together":
        client = Together()
    else:
        raise ValueError(f"Invalid API: {api}")

    completion = client.chat.completions.create(
        model=verification_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": verification_prompt},
        ],
        temperature=0.0,
    )
    response = completion.choices[0].message.content.strip()
    return response
|
src/global_edit_utils.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
from openai import OpenAI
|
| 3 |
+
|
| 4 |
+
# Task-specific clean-up prompt templates.  Each template is str.format()-ed
# with the draft text (``{text}``) before being sent to the model; see
# clean_up_text in this module.
bio_prompt = """
You are given a piece of text that is a part of a biography of an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""

# False-presupposition task: the model must output a semicolon-separated list.
fp_prompt = """
You are given a piece of text that is a part of a false presupposition task which includes outputting a list of items.
This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

The resulting list of items should be separated by semicolons with no other text.
If this list it not possible to generate, return "Abstain".
"""

hist_prompt = """
You are given a piece of text that is a part of a historical event task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""

refs_prompt = """
You are given a piece of text that is a part of a reference task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""

# GPQA: no "Abstain" escape hatch for this task.
gpqa_prompt = """
You are given a piece of text that is a part of a graduate level question answering task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}
Only return the cleaned up text. Do not include any other text:
"""

popqa_prompt = """
You are given a piece of text that is a part of a paragraph which details facts related to an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
Then, remove any redundant information.
Text: {text}

If this is not possible because the text is just a fragment of a sentence, return "Abstain".
If the text already claims a lack of knowledge about the topic, return "Abstain".
Only return the cleaned up text. Do not include any other text:
"""
# Maps the task keys accepted by clean_up_text to their prompt templates.
task_to_prompt = {
    "bio": bio_prompt,
    "fp": fp_prompt,
    "hist": hist_prompt,
    "refs": refs_prompt,
    "gpqa": gpqa_prompt,
    "popqa": popqa_prompt
}
|
| 68 |
+
|
| 69 |
+
def clean_up_text(text: str, api: str, task: str, model: str = "gpt-4.1-mini", **kwargs):
    """Clean up disfluencies in the draft response in consensus decoding.

    (Fix: this documentation previously sat in a module-level triple-quoted
    string above the function, where it was a no-op expression; it is now the
    function's real docstring.)

    Args:
        text: The text to clean up.
        api: The API to use for cleaning up the text ("openai" or "hf").
        task: The task: biography, false presupposition, historical event,
            reference, graduate question answering, paragraph question
            answering ("bio", "fp", "hist", "refs", "gpqa", "popqa").
        model: The model to use for cleaning up the text (OpenAI name; the
            "hf" backend takes its model object from ``kwargs["hf_model"]``).
        **kwargs: For api == "hf", must contain "tokenizer" and "hf_model".

    Returns:
        A string of the cleaned up text.

    Raises:
        ValueError: If ``api`` is unsupported, or the "hf" backend is missing
            its tokenizer/model.
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        tokenizer = kwargs.get("tokenizer")
        # Deliberately shadows the `model` parameter: the HF branch works with
        # a locally loaded model object instead of a model name.
        model = kwargs.get("hf_model")

        if tokenizer is None or model is None:
            raise ValueError("For 'hf', both 'tokenizer' and 'model' must be provided.")

        clean_up_prompt = task_to_prompt[task].format(text=text)

        messages = [{"role": "user", "content": clean_up_prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        terminators = [tokenizer.eos_token_id]
        # Greedy decoding (do_sample=False) for reproducible clean-ups.
        outputs = model.generate(
            input_ids,
            max_new_tokens=500,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=terminators,
        )

        # Decode only the newly generated tokens (after the prompt).
        return tokenizer.decode(
            outputs[0][input_ids.shape[-1]:],
            skip_special_tokens=True,
        ).strip()
    else:
        raise ValueError(f"Invalid API: {api}")

    # OpenAI path (falls through from the first branch).
    clean_up_prompt = task_to_prompt[task].format(text=text)

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": clean_up_prompt},
        ],
    )

    return completion.choices[0].message.content.strip()
|
src/new_alignment.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ScoreParam:
    """Container for the four alignment scoring parameters."""

    def __init__(self, match, mismatch, gap_open, gap_extend):
        # Reward for a match, penalty for a mismatch, and affine gap penalties.
        self.match = match
        self.mismatch = mismatch
        self.gap_open = gap_open
        self.gap_extend = gap_extend

    def __str__(self):
        # Same rendering as before, assembled piecewise for readability.
        parts = (
            f"Match: {self.match}",
            f"Mismatch: {self.mismatch}",
            f"Gap Open: {self.gap_open}",
            f"Gap Extend: {self.gap_extend}",
        )
        return ", ".join(parts)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SeqGraphAlignment(object):
    """Align a sequence against a partial-order alignment (POA) graph.

    Uses a three-layer dynamic program with affine gap penalties (gap_open /
    gap_extend, Gotoh-style): layer 0 (M) ends in a match/mismatch, layer 1
    (X) ends in a gap in the graph, layer 2 (Y) ends in a gap in the sequence.

    NOTE(review): the alignment drivers ``alignStringToGraphFast`` /
    ``alignStringToGraphSimple`` are not defined in this class — they are
    expected from a subclass (see TextSeqGraphAlignment).
    """

    # Default scoring: +1 match, -3 mismatch, gap open -2, gap extend -1.
    __default_score = ScoreParam(1, -3, -2, -1)

    def __init__(
        self,
        sequence,
        graph,
        fastMethod=True,
        globalAlign=False,
        score_params=__default_score,
        *args,
        **kwargs,
    ):
        """Run the alignment immediately and store the matched index lists."""
        self.score = score_params
        self.sequence = sequence
        self.graph = graph
        # Filled in below from the alignment result.
        self.stringidxs = None
        self.nodeidxs = None
        self.globalAlign = globalAlign
        if fastMethod:
            matches = self.alignStringToGraphFast(*args, **kwargs)
        else:
            matches = self.alignStringToGraphSimple(*args, **kwargs)
        self.stringidxs, self.nodeidxs = matches

    def alignmentStrings(self):
        """Render the alignment as two gapped strings ('-' marks a gap)."""
        return (
            "".join(self.sequence[i] if i is not None else "-" for i in self.stringidxs),
            "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs),
        )

    def matchscore(self, c1, c2):
        """Score one element pair: exact equality -> match, else mismatch."""
        if c1 == c2:
            return self.score.match
        else:
            return self.score.mismatch

    def matchscoreVec(self, c, v):
        """Vectorized matchscore: compare one element against an array."""
        return numpy.where(v == c, self.score.match, self.score.mismatch)

    def prevIndices(self, node, nodeIDtoIndex):
        """Return DP indices of the node's predecessors; [-1] for source nodes."""
        prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
        if not prev:
            prev = [-1]
        return prev

    def initializeDynamicProgrammingData(self):
        """Set up ID<->index maps, the 3-layer score matrix, and backtrack arrays.

        Returns:
            (nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx,
            backMtxIdx), where the arrays have shape (3, nNodes+1, len(seq)+1).
        """
        l1 = self.graph.nNodes
        l2 = len(self.sequence)

        nodeIDtoIndex = {}
        nodeIndexToID = {-1: None}
        ni = self.graph.nodeiterator()
        for index, node in enumerate(ni()):
            nodeIDtoIndex[node.ID] = index
            nodeIndexToID[index] = node.ID

        # NOTE(review): int32 scores — any fractional match scores from a
        # subclass are truncated when stored here; confirm this is intended.
        scores = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)

        if self.globalAlign:
            # M[0, i] = -inf  (-1000000000 stands in for -infinity)
            scores[0, 0, :] = [
                -1000000000 for i in range(l2 + 1)
            ]
            scores[0, 0, 0] = 0
            # X[0, i] = gap_open + i * gap_extend
            scores[1, 0, :] = [
                self.score.gap_open + i * self.score.gap_extend for i in range(l2 + 1)
            ]
            scores[1, 0, 0] = -1000000000
            # Y[0, i] = -inf
            scores[2, 0, :] = [
                -1000000000 for i in range(l2 + 1)
            ]

            ni = self.graph.nodeiterator()
            # After topology sort, the predecessors will have index less than the current node
            for index, node in enumerate(ni()):
                scores[0, index + 1, 0] = -1000000000
                scores[1, index + 1, 0] = -1000000000
                prevIdxs = self.prevIndices(node, nodeIDtoIndex)
                best = scores[2, prevIdxs[0] + 1, 0]
                for prevIdx in prevIdxs:
                    best = max(best, scores[2, prevIdx + 1, 0])
                # If we have no predecessors, we start the gap
                if prevIdxs == [-1]:
                    scores[2, index + 1, 0] = self.score.gap_open + self.score.gap_extend
                else:
                    scores[2, index + 1, 0] = best + self.score.gap_extend

        # 3D backtracking pointers: which string index / graph index / DP
        # layer the optimum at each cell came from.
        backStrIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
        backGrphIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
        backMtxIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)

        return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx

    def backtrack(self, scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID):
        """Walk the backtrack pointers from the best end cell to (0, 0).

        Returns:
            (strindexes, matches): per alignment column, the sequence index
            (or None for a gap) and the graph node ID (or None for a gap).
        """
        besti, bestj = scores.shape[1] - 1, scores.shape[2] - 1
        # Storing best matrices (DP layer with the max score) for each [i, j].
        scores_arr = numpy.array(scores)
        max_m = numpy.argmax(scores_arr, axis=0)

        if self.globalAlign:
            ni = self.graph.nodeiterator()
            # Finding the best node to start from (terminal = no out-edges).
            terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
            # NOTE(review): debug print left in the original; consider removing.
            print(terminalIndices)
            besti = terminalIndices[0] + 1
            bestscore = scores[max_m[besti, bestj], besti, bestj]
            for i in terminalIndices[1:]:
                score = scores[max_m[i + 1, bestj], i + 1, bestj]
                if score > bestscore:
                    bestscore, besti = score, i + 1
        bestm = max_m[besti, bestj]

        matches = []
        strindexes = []

        while (besti != 0 or bestj != 0):
            nextm, nexti, nextj, = backMtxIdx[bestm, besti, bestj], backGrphIdx[bestm, besti, bestj], backStrIdx[bestm, besti, bestj]
            curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]

            # Layer 0 = match/mismatch, 1 = gap in graph, 2 = gap in sequence.
            if bestm == 0:
                matches.insert(0, curnodeidx)
                strindexes.insert(0, curstridx)
            elif bestm == 1:
                matches.insert(0, None)
                strindexes.insert(0, curstridx)
            else:
                matches.insert(0, curnodeidx)
                strindexes.insert(0, None)

            bestm, besti, bestj = nextm, nexti, nextj

        return strindexes, matches
|
src/new_text_alignment.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from difflib import SequenceMatcher
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from .new_alignment import ScoreParam, SeqGraphAlignment
|
| 6 |
+
|
| 7 |
+
# Punctuation tokens available to the word-alignment scoring heuristics.
PUNCTUATION_MARKS = [".", "!", "?", ",", ":", ";", "...", "(", ")"]
|
| 8 |
+
|
| 9 |
+
class TextSeqGraphAlignment(SeqGraphAlignment):
    """Word-level alignment of a text against a POA graph.

    Splits the input text into words and aligns the word sequence against the
    graph using fuzzy (edit-distance based) word matching with affine gaps.
    """

    def __init__(
        self,
        text,
        graph,
        fastMethod=True,
        globalAlign=True,
        matchscore=1,
        mismatchscore=-3,
        gap_open=-2,
        gap_extend=-1,
        position_weight=0.1,
        *args,
        **kwargs,
    ):
        score_params = ScoreParam(
            match=matchscore, mismatch=mismatchscore, gap_open=gap_open, gap_extend=gap_extend
        )

        # Accept either a raw string (split on whitespace) or a pre-tokenized list.
        if isinstance(text, str):
            self.original_text = text
            self.sequence = text.split()
        else:
            self.sequence = text
            self.original_text = " ".join(text)
        # NOTE(review): stored but never read in this class — presumably for
        # future position-aware scoring; confirm before removing.
        self.position_weight = position_weight

        super().__init__(
            self.sequence,
            graph,
            fastMethod,
            globalAlign=globalAlign,
            score_params=score_params,
            *args,
            **kwargs,
        )

    def string_similarity(self, s1, s2):
        """Get edit-distance based similarity between two strings (0.0 - 1.0)."""
        return SequenceMatcher(None, s1, s2).ratio()

    def matchscore(self, word1: str, word2: str) -> float:
        """Score a word pair by fuzzy string similarity.

        Similarity > 0.8 counts as a full match; 0.5 - 0.8 scales the match
        score by the similarity; anything lower gets the mismatch penalty.

        Fix: the original contained a punctuation-weighting branch (x1.5 when
        either word contained a PUNCTUATION_MARKS token) placed *after* an
        unconditional ``return``, making it unreachable.  The dead code has
        been removed — behavior is unchanged.  Re-enable deliberately if
        punctuation weighting is actually wanted.
        """
        similarity = self.string_similarity(word1, word2)

        # If words are very similar, treat as match.
        if similarity > 0.8:  # Can tune this threshold
            return self.score.match
        # For less similar words, scale score based on similarity.
        if similarity > 0.5:  # Can tune this threshold too
            return self.score.match * similarity
        return self.score.mismatch

    def alignmentStrings(self):
        """Override to handle word-based alignment (columns joined by spaces)."""
        aligned_seq = [self.sequence[i] if i is not None else "-" for i in self.stringidxs]
        aligned_graph = [
            self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs
        ]
        return " ".join(aligned_seq), " ".join(aligned_graph)

    def alignStringToGraphFast(self):
        """Affine-gap DP alignment of the word sequence against the graph.

        Three DP layers (Gotoh): M = match/mismatch at (i, j), X = gap at the
        last index of the graph, Y = gap at the last index of the sequence.

        Returns:
            The backtracked (string indices, node IDs) alignment.

        Raises:
            TypeError: if the sequence is not a list of words.
        """
        if not isinstance(self.sequence, list):
            raise TypeError("Sequence must be a list of words")

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx = (
            self.initializeDynamicProgrammingData()
        )
        # M: Match at last indices, X: Gap at last index of graph, Y: gap at last index of sequence
        M, X, Y = 0, 1, 2

        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            gbase = node.text

            for j, sbase in enumerate(self.sequence):
                candidates_X, candidates_Y, candidates_M = [], [], []
                # X: consume a sequence word against a gap in the graph.
                candidates_X += [
                    (self.score.gap_open + self.score.gap_extend + scores[0, i + 1, j], i + 1, j, M),
                    (self.score.gap_extend + scores[1, i + 1, j], i + 1, j, X),
                    (self.score.gap_open + self.score.gap_extend + scores[2, i + 1, j], i + 1, j, Y),
                ]
                # Hoisted out of the predecessor loop: the word-pair score does
                # not depend on the predecessor (pure function of the two words).
                pair_score = self.matchscore(sbase, gbase)
                for predIndex in self.prevIndices(node, nodeIDtoIndex):
                    # Y: consume a graph node against a gap in the sequence.
                    candidates_Y += [
                        (self.score.gap_open + self.score.gap_extend + scores[0, predIndex + 1, j + 1], predIndex + 1, j + 1, M),
                        (self.score.gap_open + self.score.gap_extend + scores[1, predIndex + 1, j + 1], predIndex + 1, j + 1, X),
                        (self.score.gap_extend + scores[2, predIndex + 1, j + 1], predIndex + 1, j + 1, Y),
                    ]
                    # M: consume both a word and a node.
                    candidates_M += [
                        (pair_score + scores[0, predIndex + 1, j], predIndex + 1, j, M),
                        (pair_score + scores[1, predIndex + 1, j], predIndex + 1, j, X),
                        (pair_score + scores[2, predIndex + 1, j], predIndex + 1, j, Y),
                    ]

                (
                    scores[0, i + 1, j + 1],
                    backGrphIdx[0, i + 1, j + 1],
                    backStrIdx[0, i + 1, j + 1],
                    backMtxIdx[0, i + 1, j + 1],
                ) = max(candidates_M)
                (
                    scores[1, i + 1, j + 1],
                    backGrphIdx[1, i + 1, j + 1],
                    backStrIdx[1, i + 1, j + 1],
                    backMtxIdx[1, i + 1, j + 1],
                ) = max(candidates_X)
                (
                    scores[2, i + 1, j + 1],
                    backGrphIdx[2, i + 1, j + 1],
                    backStrIdx[2, i + 1, j + 1],
                    backMtxIdx[2, i + 1, j + 1],
                ) = max(candidates_Y)

        return self.backtrack(scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID)
|
src/poa_graph.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from Jonathan Dursi
|
| 3 |
+
https://github.com/ljdursi/poapy
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import collections
|
| 7 |
+
import textwrap
|
| 8 |
+
from typing import Dict, List, Optional, Union
|
| 9 |
+
|
| 10 |
+
import numpy
|
| 11 |
+
|
| 12 |
+
from .alignment import SeqGraphAlignment
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Node(object):
|
| 16 |
+
    def __init__(self, nodeID: int = -1, text: str = ""):
        """Create a POA-graph node holding one text token."""
        self.ID = nodeID      # unique integer ID within the graph
        self.text = text      # the token this node represents
        self.inEdges = {}     # predecessor node ID -> Edge (see _add_edge)
        self.outEdges = {}    # successor node ID -> Edge (see _add_edge)
        # NOTE(review): presumably IDs of nodes this one is aligned to;
        # populated outside this chunk — confirm against graph construction.
        self.alignedTo = []
|
| 22 |
+
|
| 23 |
+
def __str__(self):
|
| 24 |
+
return "(%d:%s)" % (self.ID, self.text)
|
| 25 |
+
|
| 26 |
+
def _add_edge(
|
| 27 |
+
self,
|
| 28 |
+
edgeset: Dict[int, "Node"],
|
| 29 |
+
neighbourID: int,
|
| 30 |
+
label: Union[int, List[int]],
|
| 31 |
+
from_neighbour: bool,
|
| 32 |
+
weight: int = 1,
|
| 33 |
+
):
|
| 34 |
+
if neighbourID is None:
|
| 35 |
+
return
|
| 36 |
+
# already present? just update labels
|
| 37 |
+
# otherwise create appropriately-ordered edge and proceed
|
| 38 |
+
if neighbourID in edgeset:
|
| 39 |
+
edgeset[neighbourID].weight += weight
|
| 40 |
+
if isinstance(label, list):
|
| 41 |
+
edgeset[neighbourID].labels.extend(label)
|
| 42 |
+
else:
|
| 43 |
+
edgeset[neighbourID].labels.append(label)
|
| 44 |
+
# remove duplicates
|
| 45 |
+
edgeset[neighbourID].labels = list(set(edgeset[neighbourID].labels))
|
| 46 |
+
else:
|
| 47 |
+
if from_neighbour:
|
| 48 |
+
edge = Edge(outNodeID=neighbourID, inNodeID=self.ID, label=label, weight=weight)
|
| 49 |
+
else:
|
| 50 |
+
edge = Edge(outNodeID=self.ID, inNodeID=neighbourID, label=label, weight=weight)
|
| 51 |
+
edgeset[neighbourID] = edge
|
| 52 |
+
|
| 53 |
+
def addInEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
|
| 54 |
+
self._add_edge(self.inEdges, neighbourID, label, from_neighbour=True, weight=weight)
|
| 55 |
+
|
| 56 |
+
def addOutEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
|
| 57 |
+
self._add_edge(self.outEdges, neighbourID, label, from_neighbour=False, weight=weight)
|
| 58 |
+
|
| 59 |
+
def nextNode(self, label: int):
|
| 60 |
+
"""Returns the first (presumably only) outward neighbour
|
| 61 |
+
having the given edge label"""
|
| 62 |
+
nextID = None
|
| 63 |
+
for e in self.outEdges:
|
| 64 |
+
if label in self.outEdges[e].labels:
|
| 65 |
+
nextID = e
|
| 66 |
+
return nextID
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def inDegree(self):
|
| 70 |
+
return len(self.inEdges)
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def outDegree(self):
|
| 74 |
+
return len(self.outEdges)
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def weightedInDegree(self):
|
| 78 |
+
return sum(edge.weight for edge in self.inEdges.values())
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def weightedOutDegree(self):
|
| 82 |
+
return sum(edge.weight for edge in self.outEdges.values())
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def labels(self):
|
| 86 |
+
"""Returns all the labels associated with an in-edge or an out edge."""
|
| 87 |
+
labelset = set([])
|
| 88 |
+
for e in list(self.inEdges.values()):
|
| 89 |
+
labelset = labelset.union(e.labels)
|
| 90 |
+
for e in list(self.outEdges.values()):
|
| 91 |
+
labelset = labelset.union(e.labels)
|
| 92 |
+
return list(labelset)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Edge(object):
    """A directed, weighted edge between two POA-graph nodes.

    `labels` records every sequence label that traverses this edge;
    `weight` counts how many times the edge has been reinforced.
    """

    def __init__(
        self,
        inNodeID: int = -1,
        outNodeID: int = -1,
        label: Optional[Union[int, List[int]]] = None,
        weight: int = 1,
    ):
        self.inNodeID = inNodeID
        self.outNodeID = outNodeID
        self.weight = weight

        # Normalise the label argument into a list.
        if label is None:
            self.labels = []
        elif isinstance(label, list):
            self.labels = label
        else:
            self.labels = [label]

    def addLabel(self, newlabel):
        """Record one more sequence label on this edge."""
        self.labels.append(newlabel)

    def __str__(self):
        description = "(%d) -> (%d) " % (self.inNodeID, self.outNodeID)
        # Defensive: labels is normally always a list (see __init__).
        if self.labels is None:
            return description
        return description + self.labels.__str__()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class POAGraph(object):
    """Partial-order alignment (POA) graph.

    Nodes hold one unit of sequence text; directed, labelled, weighted edges
    record which input sequences traverse each transition.  Sequences enter
    the graph either unaligned (addUnmatchedSeq) or via a computed alignment
    (incorporateSeqAlignment); consensus paths and aligned strings can then
    be extracted.

    Fixes relative to the previous revision:
      * the sort-dirty flag was stored under two spellings (`_needsort` read
        by the `needsSort` property vs `_needsSort` written by mutators), so
        mutations never marked the graph dirty -- unified to `_needsSort`;
      * `_seq_paths` was written by incorporateSeqAlignment but only ever
        initialised by a subclass -- now initialised in __init__;
      * `consensus()` accepted but ignored `excludeLabels`, so allConsenses
        kept re-extracting the same path -- excluded labels now discount
        edge weights;
      * `jsOutput()` never populated `pathdict` and emitted `is_consensus`
        while the embedded JS reads `node.isConsensus` -- both repaired.
    """

    def addUnmatchedSeq(self, seq, label: int = -1, updateSequences=True):
        """Add a completely independant (sub)string to the graph,
        and return node index to initial and final node"""
        if seq is None:
            return

        firstID, lastID = None, None
        neededSort = self.needsSort

        for text in seq:
            nodeID = self.addNode(text)
            if firstID is None:
                firstID = nodeID
            if lastID is not None:
                self.addEdge(lastID, nodeID, label)
            lastID = nodeID

        # Nodes were appended in path order, so no new order problems
        # were introduced; restore the previous dirty flag.
        self._needsSort = neededSort
        if updateSequences:
            self._seqs.append(seq)
            self._labels.append(label)
            self._starts.append(firstID)
        return firstID, lastID

    def __init__(self, seq=None, label: Optional[Union[int, List[int]]] = None):
        self._nextnodeID = 0
        self._nnodes = 0
        self._nedges = 0
        self.nodedict = {}      # nodeID -> Node
        self.nodeidlist = []    # allows a (partial) order to be imposed on the nodes
        self._needsSort = False  # True when nodeidlist may be stale
        self._labels = []       # label of each inserted sequence
        self._seqs = []         # the inserted sequences themselves
        self._starts = []       # first node ID of each sequence
        self._seq_paths = {}    # label -> list of node IDs traversed by that sequence

        if seq is not None:
            self.addUnmatchedSeq(seq, label)

    def nodeIdxToBase(self, idx):
        """Return the text of the node at position `idx` of the node order."""
        return self.nodedict[self.nodeidlist[idx]].text

    def addNode(self, text):
        """Create a new node holding `text` and return its ID."""
        nid = self._nextnodeID
        newnode = Node(nid, text)
        self.nodedict[nid] = newnode
        self.nodeidlist.append(nid)
        self._nnodes += 1
        self._nextnodeID += 1
        self._needsSort = True
        return nid

    def addEdge(self, start, end, label, weight: int = 1):
        """Add (or reinforce) a directed edge from `start` to `end`.

        Raises KeyError if either endpoint is unknown; silently ignores a
        None endpoint (used for optional head/tail stitching).
        """
        if start is None or end is None:
            return

        if start not in self.nodedict:
            raise KeyError("addEdge: Start node not in graph: " + str(start))
        if end not in self.nodedict:
            raise KeyError("addEdge: End node not in graph: " + str(end))

        oldNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree

        self.nodedict[start].addOutEdge(end, label, weight)
        self.nodedict[end].addInEdge(start, label, weight)

        newNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree

        # Only count genuinely new edges; reinforcing an existing edge
        # leaves the degree sums unchanged.
        if newNodeEdges != oldNodeEdges:
            self._nedges += 1

        self._needsSort = True
        return

    @property
    def needsSort(self):
        return self._needsSort

    @property
    def nNodes(self):
        return self._nnodes

    @property
    def nEdges(self):
        return self._nedges

    @property
    def num_sequences(self):
        return len(self._seqs)

    def get_sequences(self):
        return self._seqs

    def _simplified_graph_rep(self):
        """Collapse alignedTo-linked nodes into 'pseudonodes' for sorting.

        Returns a list of Pseudonode namedtuples whose predecessors and
        successors are expressed as pseudonode IDs.
        """
        node_to_pn = {}
        pn_to_nodes = {}

        # Find the mappings from nodes to pseudonodes
        cur_pnid = 0
        for _, node in self.nodedict.items():
            if node.ID not in node_to_pn:
                node_ids = [node.ID] + node.alignedTo
                pn_to_nodes[cur_pnid] = node_ids
                for nid in node_ids:
                    node_to_pn[nid] = cur_pnid
                cur_pnid += 1

        # create the pseudonodes
        Pseudonode = collections.namedtuple(
            "Pseudonode", ["pnode_id", "predecessors", "successors", "node_ids"]
        )
        pseudonodes = []

        for pnid in range(cur_pnid):
            nids, preds, succs = pn_to_nodes[pnid], [], []
            for nid in nids:
                node = self.nodedict[nid]
                preds += [node_to_pn[inEdge.outNodeID] for _, inEdge in node.inEdges.items()]
                succs += [node_to_pn[outEdge.inNodeID] for _, outEdge in node.outEdges.items()]

            pn = Pseudonode(pnode_id=pnid, predecessors=preds, successors=succs, node_ids=nids)
            pseudonodes.append(pn)

        return pseudonodes

    def toposort(self):
        """Sorts node list so that all incoming edges come from nodes earlier in the list."""
        sortedlist = []
        completed = set([])

        #
        # The topological sort of this graph is complicated by the alignedTo edges;
        # we want to nodes connected by such edges to remain near each other in the
        # topological sort.
        #
        # Here we'll create a simple version of the graph that merges nodes that
        # are alignedTo each other, performs the sort, and then decomposes the
        # 'pseudonodes'.
        #

        pseudonodes = self._simplified_graph_rep()

        def dfs(start, complete, sortedlist):
            # Iterative post-order DFS: a pseudonode is prepended to the
            # sorted list only once all of its successors are complete.
            stack, started = [start], set()
            while stack:
                pnodeID = stack.pop()

                if pnodeID in complete:
                    continue

                if pnodeID in started:
                    complete.add(pnodeID)
                    for nid in pseudonodes[pnodeID].node_ids:
                        sortedlist.insert(0, nid)
                    started.remove(pnodeID)
                    continue

                successors = pseudonodes[pnodeID].successors
                started.add(pnodeID)
                stack.append(pnodeID)
                stack.extend(successors)

        while len(sortedlist) < self.nNodes:
            # Each pass starts from a pseudonode with no predecessors.
            found = None
            for pnid in range(len(pseudonodes)):
                if pnid not in completed and len(pseudonodes[pnid].predecessors) == 0:
                    found = pnid
                    break
            assert found is not None
            dfs(found, completed, sortedlist)

        assert len(sortedlist) == self.nNodes
        self.nodeidlist = sortedlist
        self._needsSort = False
        return

    def testsort(self):
        """Test the nodeidlist to make sure it is topologically sorted:
        eg, all predecessors of a node preceed the node in the list"""
        if self.nodeidlist is None:
            return
        seen_nodes = set()
        for nodeidx in self.nodeidlist:
            node = self.nodedict[nodeidx]
            for in_neighbour in node.inEdges:
                assert in_neighbour in seen_nodes
            seen_nodes.add(nodeidx)
        return

    def nodeiterator(self):
        """Return a generator factory yielding nodes in topological order."""
        if self.needsSort:
            self.toposort()

        def nodegenerator():
            for nodeidx in self.nodeidlist:
                yield self.nodedict[nodeidx]

        return nodegenerator

    def __str__(self):
        selfstr = ""
        ni = self.nodeiterator()
        for node in ni():
            selfstr += node.__str__() + "\n"
            for outIdx in node.outEdges:
                selfstr += "        " + node.outEdges[outIdx].__str__() + "\n"
        return selfstr

    def incorporateSeqAlignment(self, alignment: SeqGraphAlignment, seq, label: int = -1):
        """Incorporate a SeqGraphAlignment into the graph."""
        newseq = alignment.sequence
        stringidxs = alignment.stringidxs
        nodeidxs = alignment.nodeidxs

        firstID = None  # first node of the incorporated sequence
        headID = None   # most recently threaded node
        tailID = None   # first node of the unaligned tail, if any

        path = []
        # head, tail of sequence may be unaligned; just add those into the
        # graph directly
        validstringidxs = [si for si in stringidxs if si is not None]
        startSeqIdx, endSeqIdx = validstringidxs[0], validstringidxs[-1]
        if startSeqIdx > 0:
            firstID, headID = self.addUnmatchedSeq(
                newseq[0:startSeqIdx], label, updateSequences=False
            )
        if endSeqIdx < len(newseq):
            # When endSeqIdx == len(newseq)-1 the slice is empty and
            # addUnmatchedSeq returns (None, None); addEdge ignores None.
            tailID, __ = self.addUnmatchedSeq(newseq[endSeqIdx + 1 :], label, updateSequences=False)

        # now we march along the aligned part. For each text, we find or create
        # a node in the graph:
        #   - if unmatched, the corresponding node is a new node
        #   - if matched:
        #       - if matched to a node with the same text, the node is that node
        #       - if matched to a node with a different text whch is in turn
        #         aligned to a node with the same text, that aligned node is
        #         the node
        #       - otherwise, we create a new node.
        # In all cases, we create edges (or add labels) threading through the
        # nodes.
        for sindex, matchID in zip(stringidxs, nodeidxs):
            if sindex is None:
                continue
            text = newseq[sindex]
            if matchID is None:
                nodeID = self.addNode(text)
            elif self.nodedict[matchID].text == text:
                nodeID = matchID
            else:
                otherAligns = self.nodedict[matchID].alignedTo
                foundNode = None
                for otherNodeID in otherAligns:
                    if self.nodedict[otherNodeID].text == text:
                        foundNode = otherNodeID
                if foundNode is None:
                    nodeID = self.addNode(text)
                    self.nodedict[nodeID].alignedTo = [matchID] + otherAligns
                    for otherNodeID in [matchID] + otherAligns:
                        self.nodedict[otherNodeID].alignedTo.append(nodeID)
                else:
                    nodeID = foundNode

            self.addEdge(headID, nodeID, label)
            headID = nodeID
            if firstID is None:
                firstID = headID

            path.append(nodeID)

        # finished the unaligned portion: now add an edge from the current headID to the tailID.
        self.addEdge(headID, tailID, label)

        # resort
        self.toposort()

        self._seqs.append(seq)
        self._labels.append(label)
        self._starts.append(firstID)
        self._seq_paths[label] = path
        return

    def consensus(self, excludeLabels=None):
        """Return (path, bases, labels) of the heaviest path through the graph.

        Dynamic programming runs over nodes in reverse topological order;
        at each node the outgoing edge maximising (weight, downstream score)
        is chosen.  Labels in `excludeLabels` have their edge contribution
        discounted so repeated calls (see allConsenses) can discover
        alternative paths.  The last path element is dropped on the
        assumption it is a sentinel END node -- confirm for graphs that
        have none.
        """
        if excludeLabels is None:
            excludeLabels = []

        if self.needsSort:
            self.toposort()

        nodesInReverse = self.nodeidlist[::-1]
        maxnodeID = max(nodesInReverse) + 1
        nextInPath = [-1] * maxnodeID
        scores = numpy.zeros((maxnodeID))

        for nodeID in nodesInReverse:
            bestWeightScoreEdge = (-1, -1, None)
            for neighbourID in self.nodedict[nodeID].outEdges:
                e = self.nodedict[nodeID].outEdges[neighbourID]
                # Discount sequences we were told to exclude; previously
                # excludeLabels was accepted but never applied.
                weight = e.weight - sum(lab in excludeLabels for lab in e.labels)
                weightScoreEdge = (weight, scores[neighbourID], neighbourID)

                if weightScoreEdge > bestWeightScoreEdge:
                    bestWeightScoreEdge = weightScoreEdge

            scores[nodeID] = sum(bestWeightScoreEdge[0:2])
            nextInPath[nodeID] = bestWeightScoreEdge[2]

        pos = numpy.argmax(scores)
        path = []
        bases = []
        labels = []
        while pos is not None and pos > -1:
            path.append(pos)
            bases.append(self.nodedict[pos].text)
            labels.append(self.nodedict[pos].labels)
            pos = nextInPath[pos]

        # ignore END node
        path = path[:-1]
        bases = bases[:-1]
        labels = labels[:-1]
        return path, bases, labels

    def allConsenses(self, maxfraction=0.5):
        """Iteratively extract consensus paths.

        After each pass, sequences that are represented by the consensus
        for at least `maxfraction` of their length are excluded from the
        next pass.  Returns a list of (path, bases, labellists) tuples.
        """
        allpaths = []
        allbases = []
        alllabels = []
        exclusions = []

        passno = 0
        lastlen = 1000
        maxpasses = 10

        while len(exclusions) < len(self._labels) and lastlen >= 10 and passno < maxpasses:
            path, bases, labellists = self.consensus(exclusions)
            if len(path) > 0:
                allpaths.append(path)
                allbases.append(bases)
                alllabels.append(labellists)

                labelcounts = collections.defaultdict(int)
                for ll in labellists:
                    for label in ll:
                        labelcounts[label] += 1

                for label, seq in zip(self._labels, self._seqs):
                    if label in labelcounts and labelcounts[label] >= maxfraction * len(seq):
                        exclusions.append(label)

            lastlen = len(path)
            passno += 1

        return list(zip(allpaths, allbases, alllabels))

    def generateAlignmentStrings(self):
        """Return a list of strings corresponding to the alignments in the graph"""

        # Step 1: assign node IDs to columns in the output
        # column_index[node.ID] is the position in the toposorted node list
        # of the node itself, or the earliest node it is aligned to.
        column_index = {}
        current_column = 0

        # go through nodes in toposort order
        ni = self.nodeiterator()
        for node in ni():
            other_columns = [
                column_index[other] for other in node.alignedTo if other in column_index
            ]
            if other_columns:
                found_idx = min(other_columns)
            else:
                found_idx = current_column
                current_column += 1

            column_index[node.ID] = found_idx

        ncolumns = current_column

        # Step 2: given the column indexes, populate the strings
        # corresponding to the sequences inserted in the graph
        seqnames = []
        alignstrings = []
        for label, start in zip(self._labels, self._starts):
            seqnames.append(label)
            curnode_id = start
            charlist = ["-"] * ncolumns
            while curnode_id is not None:
                node = self.nodedict[curnode_id]
                charlist[column_index[curnode_id]] = node.text
                curnode_id = node.nextNode(label)
            alignstrings.append("".join(charlist))

        # Step 3: Same as step 2, but with consensus sequences
        consenses = self.allConsenses()
        for i, consensus in enumerate(consenses):
            seqnames.append("Consensus" + str(i))
            charlist = ["-"] * ncolumns
            for path, text in zip(consensus[0], consensus[1]):
                charlist[column_index[path]] = text
            alignstrings.append("".join(charlist))

        return list(zip(seqnames, alignstrings))

    def jsOutput(self, verbose: bool = False, annotate_consensus: bool = True):
        """returns a list of strings containing a a description of the graph for viz.js, http://visjs.org"""

        # get the consensus sequence, which we'll use as the "spine" of the
        # graph
        pathdict = {}
        if annotate_consensus:
            path, __, __ = self.consensus()
            # Fix: the consensus path was computed but pathdict was never
            # populated, so no node ever matched the spine test below.
            for i, nodeID in enumerate(path):
                pathdict[nodeID] = i * 150
        lines = ["var nodes = ["]

        ni = self.nodeiterator()
        count = 0
        for node in ni():
            line = "    {id:" + str(node.ID) + ', label: "' + str(node.ID) + ": " + node.text + '"'
            # count is never incremented, so count % 5 == 0 is always true
            # and every consensus node gets pinned -- presumably intended
            # to sample every 5th node; left as-is to preserve behaviour.
            if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
                # Fix: emit `isConsensus` (camelCase), which is what the
                # JavaScript in htmlOutput's footer actually reads.
                line += (
                    ", x: "
                    + str(pathdict[node.ID])
                    + ", y: 0 , fixed: { x:true, y:false},"
                    + "color: '#7BE141', isConsensus:true},"
                )
            else:
                line += "},"
            lines.append(line)

        # drop the trailing comma from the last node entry
        lines[-1] = lines[-1][:-1]
        lines.append("];")

        lines.append(" ")

        lines.append("var edges = [")
        ni = self.nodeiterator()
        for node in ni():
            nodeID = str(node.ID)
            for edge in node.outEdges:
                target = str(edge)
                weight = str(len(node.outEdges[edge].labels) + 1.5)
                lines.append(
                    "    {from: "
                    + nodeID
                    + ", to: "
                    + target
                    + ", value: "
                    + weight
                    + ", color: '#4b72b0', arrows: 'to'},"
                )
            if verbose:
                for alignededge in node.alignedTo:
                    # These edges indicate alignment to different bases, and are
                    # undirected; thus make sure we only plot them once:
                    if node.ID > alignededge:
                        continue
                    target = str(alignededge)
                    lines.append(
                        "    {from: "
                        + nodeID
                        + ", to: "
                        + target
                        + ', value: 1, style: "dash-line", color: "red"},'
                    )

        lines[-1] = lines[-1][:-1]
        lines.append("];")
        return lines

    def htmlOutput(self, outfile, verbose: bool = False, annotate_consensus: bool = True):
        """Write a self-contained vis-network HTML page for this graph to `outfile`."""
        header = """
        <!doctype html>
        <html>
        <head>
          <title>POA Graph Alignment</title>

          <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
        </head>

        <body>

        <div id="loadingProgress">0%</div>

        <div id="mynetwork"></div>

        <script type="text/javascript">
          // create a network
        """
        outfile.write(textwrap.dedent(header[1:]))
        lines = self.jsOutput(verbose=verbose, annotate_consensus=annotate_consensus)
        for line in lines:
            outfile.write(line + "\n")
        footer = """
        var container = document.getElementById('mynetwork');
        var data= {
            nodes: nodes,
            edges: edges,
        };
        var options = {
            width: '100%',
            height: '800px',
            physics: {
                enabled: false,
                stabilization: {
                    updateInterval: 10,
                },
                hierarchicalRepulsion: {
                    avoidOverlap: 0.9,
                },
            },
            edges: {
                color: {
                    inherit: false
                }
            },
            layout: {
                hierarchical: {
                    direction: "UD",
                    sortMethod: "directed",
                    shakeTowards: "roots",
                    levelSeparation: 150, // Adjust as needed
                    nodeSpacing: 100, // Adjust as needed
                    treeSpacing: 200, // Adjust as needed
                    parentCentralization: true,
                }
            }
        };
        var network = new vis.Network(container, data, options);

        network.on('beforeDrawing', function(ctx) {
            nodes.forEach(function(node) {
                if (node.isConsensus) {
                    // Set the level of spine nodes to the bottom
                    network.body.data.nodes.update({
                        id: node.id,
                        level: 0 // Set level to 0 for spine nodes
                    });
                }
            });
        });

        network.on("stabilizationProgress", function (params) {
            document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
        });
        network.once("stabilizationIterationsDone", function () {
            document.getElementById("loadingProgress").innerText = "100%";
            setTimeout(function () {
                document.getElementById("loadingProgress").style.display = "none";
            }, 500);
        });
        </script>

        </body>
        </html>
        """
        outfile.write(textwrap.dedent(footer))
|
src/text_poa_graph.py
ADDED
|
@@ -0,0 +1,802 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced version of POAGraph for text alignment
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pickle
|
| 6 |
+
import textwrap
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from src.text_poa_graph_utils import path_sim_llm
|
| 13 |
+
from src.global_edit_utils import clean_up_text
|
| 14 |
+
|
| 15 |
+
from .new_text_alignment import TextSeqGraphAlignment
|
| 16 |
+
from .poa_graph import Node, POAGraph
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TextNode(Node):
    """Graph node specialised for word-level text alignment.

    Extends the base Node with bookkeeping for alternate phrasings of the
    same position and per-node LLM token accounting.
    """

    def __init__(self, nodeID=-1, text=""):
        super().__init__(nodeID, text)
        self.variations = {}  # Track alternate phrasings (sequence_id -> text)
        self.sequences = []  # Track sequences that contain this node
        self.influenceScore = 0  # accumulated influence score for this node
        self.num_tokens_used = 0  # LLM tokens spent on this node

    def add_variation(self, text, sequence_id):
        # Record `text` as the phrasing used by `sequence_id` at this node.
        self.variations[sequence_id] = text

    @property
    def is_stable(self):
        """A node is stable if it appears frequently enough relative to total sequences"""
        # NOTE(review): `self.frequency` and `self.graph` are never assigned
        # in this class or in the visible base class -- presumably attached
        # externally before this property is read.  Confirm; otherwise this
        # raises AttributeError.
        return self.frequency >= self.graph.stability_threshold
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TextPOAGraph(POAGraph):
|
| 37 |
+
    def __init__(self, text=None, label=-1):
        """Create a text POA graph, optionally seeded with `text` under `label`."""
        self.consensus_node_ids = []  # node IDs forming the current consensus
        self._seq_paths = {}  # label -> list of node IDs traversed by that sequence
        self.end_id = -1  # presumably the sentinel END node ID (-1 until created)
        self.start_id = -1  # presumably the sentinel START node ID (-1 until created)
        self.failed = False  # set when a processing step fails
        self.num_input_tokens_used = 0  # cumulative LLM input-token count
        self.num_output_tokens_used = 0  # cumulative LLM output-token count
        # Base-class init may call the overridden addUnmatchedSeq, which
        # writes into _seq_paths -- hence the attributes above come first.
        super().__init__(text, label)
|
| 46 |
+
|
| 47 |
+
def addNode(self, text):
|
| 48 |
+
"""Override to use TextNode"""
|
| 49 |
+
nid = self._nextnodeID
|
| 50 |
+
newnode = TextNode(nid, text)
|
| 51 |
+
self.nodedict[nid] = newnode
|
| 52 |
+
self.nodeidlist.append(nid)
|
| 53 |
+
self._nnodes += 1
|
| 54 |
+
self._nextnodeID += 1
|
| 55 |
+
self._needsSort = True
|
| 56 |
+
return nid
|
| 57 |
+
|
| 58 |
+
    def addUnmatchedSeq(self, text, label=-1, updateSequences=True):
        """Modified to handle text sequences

        Adds `text` -- a string (split on whitespace) or a pre-tokenised
        list of words -- as an independent chain of nodes.  Returns the
        (first, last) node IDs of the chain, or None if `text` is None.
        """
        if text is None:
            return

        # Handle both string and list input
        if isinstance(text, str):
            words = text.split()
        else:
            words = text

        firstID, lastID = None, None
        neededSort = self.needsSort

        path = []
        for word in words:
            nodeID = self.addNode(word)
            if firstID is None:
                firstID = nodeID
            if lastID is not None:
                # Chain consecutive words under this sequence's label.
                self.addEdge(lastID, nodeID, label=label)
            lastID = nodeID
            path.append(nodeID)

        # Appending a simple chain cannot break an existing topological
        # order, so restore the previous dirty flag.
        self._needsort = neededSort
        if updateSequences:
            self._seqs.append(words)
            self._labels.append(label)
            self._starts.append(firstID)
            self._seq_paths[label] = path

        return firstID, lastID
|
| 90 |
+
|
| 91 |
+
def add_text(self, text, label=-1):
|
| 92 |
+
"""Main method to add new text to the alignment"""
|
| 93 |
+
if len(self._seqs) == 0:
|
| 94 |
+
# First sequence - just add it
|
| 95 |
+
self.addUnmatchedSeq(text, label)
|
| 96 |
+
else:
|
| 97 |
+
# Align to existing graph
|
| 98 |
+
alignment = TextSeqGraphAlignment(
|
| 99 |
+
text, self, matchscore=2, mismatchscore=-1, gapscore=-2
|
| 100 |
+
)
|
| 101 |
+
self.incorporateSeqAlignment(alignment, text, label)
|
| 102 |
+
|
| 103 |
+
# Update node frequencies
|
| 104 |
+
self._update_frequencies()
|
| 105 |
+
|
| 106 |
+
def removeNode(self, nodeID):
|
| 107 |
+
"""Override to handle text nodes"""
|
| 108 |
+
node = self.nodedict[nodeID]
|
| 109 |
+
if node is None:
|
| 110 |
+
return
|
| 111 |
+
|
| 112 |
+
# Remove all edges to this node
|
| 113 |
+
out_edges = node.outEdges.copy()
|
| 114 |
+
in_edges = node.inEdges.copy()
|
| 115 |
+
|
| 116 |
+
for edge in out_edges:
|
| 117 |
+
self.removeEdge(node.ID, edge)
|
| 118 |
+
for edge in in_edges:
|
| 119 |
+
self.removeEdge(edge, node.ID)
|
| 120 |
+
|
| 121 |
+
# Remove from graph
|
| 122 |
+
del self.nodedict[nodeID]
|
| 123 |
+
self.nodeidlist.remove(nodeID)
|
| 124 |
+
|
| 125 |
+
for path in self._seq_paths.values():
|
| 126 |
+
if nodeID in path:
|
| 127 |
+
path.remove(nodeID)
|
| 128 |
+
|
| 129 |
+
self._nnodes -= 1
|
| 130 |
+
self._needsSort = True
|
| 131 |
+
|
| 132 |
+
def removeEdge(self, nodeID1, nodeID2):
|
| 133 |
+
"""Override to handle text nodes"""
|
| 134 |
+
node1 = self.nodedict[nodeID1]
|
| 135 |
+
node2 = self.nodedict[nodeID2]
|
| 136 |
+
|
| 137 |
+
if node1 is None or node2 is None:
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
# Remove from graph
|
| 141 |
+
del node1.outEdges[nodeID2]
|
| 142 |
+
del node2.inEdges[nodeID1]
|
| 143 |
+
|
| 144 |
+
def merge_consensus_nodes(self, verbose: bool = False):
|
| 145 |
+
self.toposort()
|
| 146 |
+
# reset consensus node ids
|
| 147 |
+
self.consensus_node_ids = []
|
| 148 |
+
nodes = list(self.nodeiterator()())
|
| 149 |
+
consensus_segments = []
|
| 150 |
+
i = 0
|
| 151 |
+
while i < len(nodes):
|
| 152 |
+
node = nodes[i]
|
| 153 |
+
out_weight = sum(e.weight for e in node.outEdges.values())
|
| 154 |
+
in_weight = sum(e.weight for e in node.inEdges.values())
|
| 155 |
+
|
| 156 |
+
if out_weight in [0, self.num_sequences] and in_weight in [0, self.num_sequences]:
|
| 157 |
+
consensus_segment = [(node.ID, node.text)]
|
| 158 |
+
next_node = node
|
| 159 |
+
while (i + 1) < len(nodes) and len(next_node.outEdges) == 1:
|
| 160 |
+
next_node = nodes[i + 1]
|
| 161 |
+
next_out_weight = sum(e.weight for e in next_node.outEdges.values())
|
| 162 |
+
next_in_weight = sum(e.weight for e in next_node.inEdges.values())
|
| 163 |
+
|
| 164 |
+
if (
|
| 165 |
+
next_out_weight != self.num_sequences
|
| 166 |
+
or next_in_weight != self.num_sequences
|
| 167 |
+
):
|
| 168 |
+
break
|
| 169 |
+
|
| 170 |
+
consensus_segment.append((next_node.ID, next_node.text))
|
| 171 |
+
i += 1
|
| 172 |
+
consensus_segments.append(consensus_segment)
|
| 173 |
+
i += 1
|
| 174 |
+
# merge consensus nodes into a single node
|
| 175 |
+
for segment in consensus_segments:
|
| 176 |
+
if len(segment) == 1:
|
| 177 |
+
self.consensus_node_ids.append(segment[0][0])
|
| 178 |
+
continue
|
| 179 |
+
merged_text = " ".join([text for _, text in segment])
|
| 180 |
+
first_node_id = segment[0][0]
|
| 181 |
+
last_node_id = segment[-1][0]
|
| 182 |
+
|
| 183 |
+
self.nodedict[last_node_id].text = merged_text
|
| 184 |
+
self.consensus_node_ids.append(last_node_id)
|
| 185 |
+
|
| 186 |
+
# attach all incoming edges to first node to last node
|
| 187 |
+
for id, edge in self.nodedict[first_node_id].inEdges.items():
|
| 188 |
+
weight = edge.weight
|
| 189 |
+
for _ in range(weight):
|
| 190 |
+
self.addEdge(id, last_node_id, label=edge.labels)
|
| 191 |
+
|
| 192 |
+
# delete all nodes except last node
|
| 193 |
+
for node_id, _ in segment[:-1]:
|
| 194 |
+
self.removeNode(node_id)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if verbose:
|
| 199 |
+
print(self.consensus_node_ids)
|
| 200 |
+
|
| 201 |
+
"""
|
| 202 |
+
find all paths between start_node_id and end_node_id from original sequences
|
| 203 |
+
return a list of dictionaries with the following keys:
|
| 204 |
+
- path: list of node ids in the path (excluding start and including end)
|
| 205 |
+
- text: text of the path (excluding start and end)
|
| 206 |
+
- weight: minimal edge weight across all edges in the path
|
| 207 |
+
- labels: intersection of all edge labels in the path
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
def find_paths_between(self, start_node_id: int, end_node_id: int):
|
| 211 |
+
# find all paths between start_node_id and end_node_id from original sequences
|
| 212 |
+
path_dicts = []
|
| 213 |
+
|
| 214 |
+
# keep track of visited paths to avoid duplicates
|
| 215 |
+
visited_paths = set()
|
| 216 |
+
|
| 217 |
+
for _, path in self._seq_paths.items():
|
| 218 |
+
start_index = path.index(start_node_id) if start_node_id in path else None
|
| 219 |
+
end_index = path.index(end_node_id) if end_node_id in path else None
|
| 220 |
+
|
| 221 |
+
# print(start_index, end_index)
|
| 222 |
+
# print(path)
|
| 223 |
+
|
| 224 |
+
if (
|
| 225 |
+
start_index is not None
|
| 226 |
+
and end_index is not None
|
| 227 |
+
and end_index - start_index > 1
|
| 228 |
+
and tuple(path[start_index + 1 : end_index + 1]) not in visited_paths
|
| 229 |
+
):
|
| 230 |
+
# intersection of all edge labels in the path
|
| 231 |
+
path_labels = set.intersection(
|
| 232 |
+
*[
|
| 233 |
+
set(self.nodedict[next_node_id].inEdges[node_id].labels)
|
| 234 |
+
for node_id, next_node_id in zip(
|
| 235 |
+
path[start_index:end_index], path[start_index + 1 : end_index + 1]
|
| 236 |
+
)
|
| 237 |
+
]
|
| 238 |
+
)
|
| 239 |
+
path_weight = len(path_labels)
|
| 240 |
+
path_dicts.append(
|
| 241 |
+
{
|
| 242 |
+
|
| 243 |
+
"path": path[start_index + 1 : end_index + 1],
|
| 244 |
+
"body_text": " ".join(
|
| 245 |
+
[
|
| 246 |
+
self.nodedict[node_id].text
|
| 247 |
+
for node_id in path[start_index + 1 : end_index]
|
| 248 |
+
]
|
| 249 |
+
),
|
| 250 |
+
"begin_text": self.nodedict[path[start_index]].text,
|
| 251 |
+
"end_text": self.nodedict[path[end_index]].text,
|
| 252 |
+
"weight": path_weight,
|
| 253 |
+
"labels": path_labels,
|
| 254 |
+
}
|
| 255 |
+
)
|
| 256 |
+
visited_paths.add(tuple(path[start_index + 1 : end_index + 1]))
|
| 257 |
+
|
| 258 |
+
return path_dicts
|
| 259 |
+
|
| 260 |
+
def _follow_path(self, start_id):
|
| 261 |
+
"""Follow all possible paths from a node"""
|
| 262 |
+
paths = []
|
| 263 |
+
visited = set()
|
| 264 |
+
|
| 265 |
+
def dfs(node_id, current_path):
|
| 266 |
+
if node_id in visited:
|
| 267 |
+
return
|
| 268 |
+
visited.add(node_id)
|
| 269 |
+
node = self.nodedict[node_id]
|
| 270 |
+
|
| 271 |
+
if not node.outEdges:
|
| 272 |
+
paths.append(current_path + [node_id])
|
| 273 |
+
return
|
| 274 |
+
|
| 275 |
+
for next_id in node.outEdges:
|
| 276 |
+
dfs(next_id, current_path + [node_id])
|
| 277 |
+
|
| 278 |
+
dfs(start_id, [])
|
| 279 |
+
return paths
|
| 280 |
+
|
| 281 |
+
def merge_paths_between(
|
| 282 |
+
self,
|
| 283 |
+
start_node_id: int,
|
| 284 |
+
end_node_id: int,
|
| 285 |
+
path_sim_type: str = "llm",
|
| 286 |
+
verbose: bool = False,
|
| 287 |
+
**kwargs,
|
| 288 |
+
):
|
| 289 |
+
path_dicts = self.find_paths_between(start_node_id, end_node_id)
|
| 290 |
+
|
| 291 |
+
if path_sim_type == "llm":
|
| 292 |
+
api = kwargs.get("api", "openai")
|
| 293 |
+
model = kwargs.get("model", "gpt-4o-mini")
|
| 294 |
+
domain = kwargs.get("domain", None)
|
| 295 |
+
similarity_judge_prompt = kwargs.get("similarity_judge_prompt", None)
|
| 296 |
+
|
| 297 |
+
def path_sim_func(path1_text, path2_text):
|
| 298 |
+
return path_sim_llm(
|
| 299 |
+
path1_text,
|
| 300 |
+
path2_text,
|
| 301 |
+
api=api,
|
| 302 |
+
model=model,
|
| 303 |
+
domain=domain,
|
| 304 |
+
custom_similarity_judge_prompt=similarity_judge_prompt,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
elif path_sim_type == "cosine":
|
| 308 |
+
pass
|
| 309 |
+
# embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 310 |
+
# threshold = kwargs.get("threshold", 0.9)
|
| 311 |
+
# path_sim_func = path_sim_cosine(embedding_model, threshold)
|
| 312 |
+
else:
|
| 313 |
+
raise ValueError(f"Invalid path similarity type: {path_sim_type}")
|
| 314 |
+
|
| 315 |
+
# merge paths based on semantic similarity
|
| 316 |
+
path_equivalence_classes = {}
|
| 317 |
+
class_count = 0
|
| 318 |
+
|
| 319 |
+
for path_dict in path_dicts:
|
| 320 |
+
if verbose:
|
| 321 |
+
print(path_dict)
|
| 322 |
+
found_class = False
|
| 323 |
+
for _, eq_class in path_equivalence_classes.items():
|
| 324 |
+
# check if path dict is already in an equivalence class
|
| 325 |
+
path1_text = (
|
| 326 |
+
path_dict["begin_text"]
|
| 327 |
+
+ " "
|
| 328 |
+
+ path_dict["body_text"]
|
| 329 |
+
+ " "
|
| 330 |
+
+ path_dict["end_text"]
|
| 331 |
+
)
|
| 332 |
+
path2_text = (
|
| 333 |
+
eq_class[0]["begin_text"]
|
| 334 |
+
+ " "
|
| 335 |
+
+ eq_class[0]["body_text"]
|
| 336 |
+
+ " "
|
| 337 |
+
+ eq_class[0]["end_text"]
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
judgement, num_input_tokens, num_output_tokens = path_sim_func(
|
| 341 |
+
path1_text, path2_text
|
| 342 |
+
)
|
| 343 |
+
self.num_input_tokens_used += num_input_tokens
|
| 344 |
+
self.num_output_tokens_used += num_output_tokens
|
| 345 |
+
if judgement:
|
| 346 |
+
eq_class.append(path_dict)
|
| 347 |
+
found_class = True
|
| 348 |
+
break
|
| 349 |
+
if not found_class:
|
| 350 |
+
class_count += 1
|
| 351 |
+
path_equivalence_classes[class_count] = [path_dict]
|
| 352 |
+
|
| 353 |
+
nodes_to_remove = set() # Track nodes to remove
|
| 354 |
+
for _, eq_class in path_equivalence_classes.items():
|
| 355 |
+
path_dict = eq_class[0]
|
| 356 |
+
|
| 357 |
+
if verbose:
|
| 358 |
+
print(eq_class)
|
| 359 |
+
# add new node with merged text
|
| 360 |
+
new_node_id = self.addNode(path_dict["body_text"])
|
| 361 |
+
for sequence_id in path_dict["labels"]:
|
| 362 |
+
self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
|
| 363 |
+
|
| 364 |
+
# collect nodes to remove from first path
|
| 365 |
+
nodes_to_remove.update(path_dict["path"][:-1])
|
| 366 |
+
|
| 367 |
+
# process data regarding weights and labels
|
| 368 |
+
labels = list(path_dict["labels"])
|
| 369 |
+
weight = path_dict["weight"]
|
| 370 |
+
self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
|
| 371 |
+
|
| 372 |
+
# Updated seq_paths for all labels to include new_node betwwen start_node and end_node
|
| 373 |
+
for label in labels:
|
| 374 |
+
index = self._seq_paths[label].index(start_node_id)
|
| 375 |
+
if (
|
| 376 |
+
index + 1 < len(self._seq_paths[label])
|
| 377 |
+
and self._seq_paths[label][index + 1] != new_node_id
|
| 378 |
+
):
|
| 379 |
+
self._seq_paths[label].insert(index + 1, new_node_id)
|
| 380 |
+
|
| 381 |
+
self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
|
| 382 |
+
|
| 383 |
+
self.nodedict[new_node_id].sequences = labels
|
| 384 |
+
# process additional paths
|
| 385 |
+
for path_dict in eq_class[1:]:
|
| 386 |
+
for sequence_id in path_dict["labels"]:
|
| 387 |
+
self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
|
| 388 |
+
nodes_to_remove.update(path_dict["path"][:-1])
|
| 389 |
+
|
| 390 |
+
# copy incoming edges to new node
|
| 391 |
+
labels = list(path_dict["labels"])
|
| 392 |
+
weight = path_dict["weight"]
|
| 393 |
+
self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)
|
| 394 |
+
|
| 395 |
+
# Updated seq_paths for all labels to include new_node betwwen start_node and end_node
|
| 396 |
+
for label in labels:
|
| 397 |
+
index = self._seq_paths[label].index(start_node_id)
|
| 398 |
+
if (
|
| 399 |
+
index + 1 < len(self._seq_paths[label])
|
| 400 |
+
and self._seq_paths[label][index + 1] != new_node_id
|
| 401 |
+
):
|
| 402 |
+
self._seq_paths[label].insert(index + 1, new_node_id)
|
| 403 |
+
|
| 404 |
+
self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
|
| 405 |
+
self.nodedict[new_node_id].sequences.extend(labels)
|
| 406 |
+
|
| 407 |
+
self.nodedict[new_node_id].sequences = list(set(self.nodedict[new_node_id].sequences))
|
| 408 |
+
|
| 409 |
+
# Remove all collected nodes after processing
|
| 410 |
+
for node_id in nodes_to_remove:
|
| 411 |
+
if node_id in self.nodedict:
|
| 412 |
+
if verbose:
|
| 413 |
+
print(f"Removing node {node_id}")
|
| 414 |
+
self.removeNode(node_id)
|
| 415 |
+
|
| 416 |
+
def merge_divergent_paths(self, path_sim_type: str = "llm", verbose: bool = False, **kwargs):
|
| 417 |
+
# add dummy end node to the end of the graph
|
| 418 |
+
if not self.consensus_node_ids:
|
| 419 |
+
self.merge_consensus_nodes(verbose=verbose)
|
| 420 |
+
|
| 421 |
+
self.toposort()
|
| 422 |
+
|
| 423 |
+
if self.start_id == -1:
|
| 424 |
+
if verbose:
|
| 425 |
+
print("Adding start node")
|
| 426 |
+
self.start_id = self.addNode(text="START")
|
| 427 |
+
self._nextnodeID += 1
|
| 428 |
+
self.consensus_node_ids.insert(0, self.start_id)
|
| 429 |
+
|
| 430 |
+
for label, path in self._seq_paths.items():
|
| 431 |
+
self.addEdge(self.start_id, path[0], label=label, weight=1)
|
| 432 |
+
path.insert(0, self.start_id)
|
| 433 |
+
|
| 434 |
+
if self.end_id == -1:
|
| 435 |
+
if verbose:
|
| 436 |
+
print("Adding end node")
|
| 437 |
+
self.end_id = self.addNode(text="END")
|
| 438 |
+
self._nextnodeID += 1
|
| 439 |
+
self.consensus_node_ids = self.consensus_node_ids + [self.end_id]
|
| 440 |
+
|
| 441 |
+
for label, path in self._seq_paths.items():
|
| 442 |
+
self.addEdge(path[-1], self.end_id, label=label, weight=1)
|
| 443 |
+
path.append(self.end_id)
|
| 444 |
+
|
| 445 |
+
for i in tqdm(range(len(self.consensus_node_ids) - 1)):
|
| 446 |
+
if verbose:
|
| 447 |
+
print(self.consensus_node_ids[i], self.consensus_node_ids[i + 1])
|
| 448 |
+
self.merge_paths_between(
|
| 449 |
+
self.consensus_node_ids[i],
|
| 450 |
+
self.consensus_node_ids[i + 1],
|
| 451 |
+
path_sim_type=path_sim_type,
|
| 452 |
+
verbose=verbose,
|
| 453 |
+
**kwargs,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
def get_variable_node_ids(self):
|
| 457 |
+
return [
|
| 458 |
+
node.ID for node in self.nodedict.values() if node.ID not in self.consensus_node_ids
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
def compress_paths_between(self, start_node_id: int, end_node_id: int):
|
| 462 |
+
pass
|
| 463 |
+
|
| 464 |
+
def compress_graph(self):
|
| 465 |
+
pass
|
| 466 |
+
|
| 467 |
+
def update_influence_scores(self, outcome: Dict[int, float], discount_factor: float = 0.2):
|
| 468 |
+
self.toposort()
|
| 469 |
+
direct_scores = []
|
| 470 |
+
for node in self.nodedict.values():
|
| 471 |
+
next_out_weight = sum(e.weight for e in node.outEdges.values())
|
| 472 |
+
next_in_weight = sum(e.weight for e in node.inEdges.values())
|
| 473 |
+
if next_out_weight == self.num_sequences and next_in_weight == self.num_sequences:
|
| 474 |
+
out_list = []
|
| 475 |
+
for edge in node.outEdges.values():
|
| 476 |
+
for _ in range(len(set(edge.labels))):
|
| 477 |
+
out_list.append(np.mean([outcome[label] for label in set(edge.labels)]))
|
| 478 |
+
direct_scores.append((node.ID, np.var(out_list)))
|
| 479 |
+
|
| 480 |
+
scores = direct_scores.copy()
|
| 481 |
+
|
| 482 |
+
# Start from the end and propagate influence backward
|
| 483 |
+
for i in range(len(scores) - 2, -1, -1):
|
| 484 |
+
# Current node gets its direct influence plus discounted influence of next node
|
| 485 |
+
current_direct = scores[i][1]
|
| 486 |
+
next_total = scores[i + 1][1]
|
| 487 |
+
scores[i] = (scores[i][0], current_direct + discount_factor * next_total)
|
| 488 |
+
|
| 489 |
+
scores.sort(key=lambda x: x[1], reverse=True)
|
| 490 |
+
return scores
|
| 491 |
+
|
| 492 |
+
def jsOutput(
|
| 493 |
+
self,
|
| 494 |
+
verbose: bool = False,
|
| 495 |
+
annotate_consensus: bool = True,
|
| 496 |
+
color_annotations: Dict[int, str] = None,
|
| 497 |
+
):
|
| 498 |
+
"""returns a list of strings containing a a description of the graph for viz.js, http://visjs.org"""
|
| 499 |
+
|
| 500 |
+
# get the consensus sequence, which we'll use as the "spine" of the
|
| 501 |
+
# graph
|
| 502 |
+
pathdict = {}
|
| 503 |
+
if annotate_consensus:
|
| 504 |
+
path, __, __ = self.consensus()
|
| 505 |
+
lines = ["var nodes = ["]
|
| 506 |
+
|
| 507 |
+
ni = self.nodeiterator()
|
| 508 |
+
count = 0
|
| 509 |
+
for node in ni():
|
| 510 |
+
title_text = ""
|
| 511 |
+
if node.sequences:
|
| 512 |
+
title_text += f"Sequences: {node.sequences}"
|
| 513 |
+
if node.variations:
|
| 514 |
+
title_text += ";;;".join(
|
| 515 |
+
[f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
|
| 516 |
+
)
|
| 517 |
+
title_text = title_text.replace('"', "'")
|
| 518 |
+
line = (
|
| 519 |
+
" {id:"
|
| 520 |
+
+ str(node.ID)
|
| 521 |
+
+ ', label: "'
|
| 522 |
+
+ str(node.ID)
|
| 523 |
+
+ ": "
|
| 524 |
+
+ node.text.replace('"', "'")
|
| 525 |
+
+ '", title: '
|
| 526 |
+
+ '"'
|
| 527 |
+
+ title_text
|
| 528 |
+
+ '",'
|
| 529 |
+
)
|
| 530 |
+
if color_annotations and node.ID in color_annotations:
|
| 531 |
+
line += f" color: '{color_annotations[node.ID]}', "
|
| 532 |
+
if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
|
| 533 |
+
line += (
|
| 534 |
+
", x: "
|
| 535 |
+
+ str(pathdict[node.ID])
|
| 536 |
+
+ ", y: 0 , fixed: { x:true, y:false},"
|
| 537 |
+
+ "color: '#7BE141', is_consensus:true},"
|
| 538 |
+
)
|
| 539 |
+
else:
|
| 540 |
+
line += "},"
|
| 541 |
+
lines.append(line)
|
| 542 |
+
|
| 543 |
+
lines[-1] = lines[-1][:-1]
|
| 544 |
+
lines.append("];")
|
| 545 |
+
|
| 546 |
+
lines.append(" ")
|
| 547 |
+
|
| 548 |
+
lines.append("var edges = [ ")
|
| 549 |
+
ni = self.nodeiterator()
|
| 550 |
+
for node in ni():
|
| 551 |
+
nodeID = str(node.ID)
|
| 552 |
+
for edge in node.outEdges:
|
| 553 |
+
target = str(edge)
|
| 554 |
+
weight = str(node.outEdges[edge].weight + 1.5)
|
| 555 |
+
lines.append(
|
| 556 |
+
" {from: "
|
| 557 |
+
+ nodeID
|
| 558 |
+
+ ", to: "
|
| 559 |
+
+ target
|
| 560 |
+
+ ", value: "
|
| 561 |
+
+ weight
|
| 562 |
+
+ ", color: '#4b72b0', arrows: 'to'},"
|
| 563 |
+
)
|
| 564 |
+
if verbose:
|
| 565 |
+
for alignededge in node.alignedTo:
|
| 566 |
+
# These edges indicate alignment to different bases, and are
|
| 567 |
+
# undirected; thus make sure we only plot them once:
|
| 568 |
+
if node.ID > alignededge:
|
| 569 |
+
continue
|
| 570 |
+
target = str(alignededge)
|
| 571 |
+
lines.append(
|
| 572 |
+
" {from: "
|
| 573 |
+
+ nodeID
|
| 574 |
+
+ ", to: "
|
| 575 |
+
+ target
|
| 576 |
+
+ ', value: 1, style: "dash-line", color: "red"},'
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
lines[-1] = lines[-1][:-1]
|
| 580 |
+
lines.append("];")
|
| 581 |
+
return lines
|
| 582 |
+
|
| 583 |
+
def htmlOutput(
|
| 584 |
+
self,
|
| 585 |
+
outfile,
|
| 586 |
+
verbose: bool = False,
|
| 587 |
+
annotate_consensus: bool = True,
|
| 588 |
+
color_annotations: Dict[int, str] = None,
|
| 589 |
+
):
|
| 590 |
+
header = """
|
| 591 |
+
<!doctype html>
|
| 592 |
+
<html>
|
| 593 |
+
<head>
|
| 594 |
+
<title>POA Graph Alignment</title>
|
| 595 |
+
|
| 596 |
+
<script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
|
| 597 |
+
</head>
|
| 598 |
+
|
| 599 |
+
<body>
|
| 600 |
+
|
| 601 |
+
<div id="loadingProgress">0%</div>
|
| 602 |
+
|
| 603 |
+
<div id="mynetwork"></div>
|
| 604 |
+
|
| 605 |
+
<script type="text/javascript">
|
| 606 |
+
// create a network
|
| 607 |
+
"""
|
| 608 |
+
outfile.write(textwrap.dedent(header[1:]))
|
| 609 |
+
lines = self.jsOutput(
|
| 610 |
+
verbose=verbose,
|
| 611 |
+
annotate_consensus=annotate_consensus,
|
| 612 |
+
color_annotations=color_annotations,
|
| 613 |
+
)
|
| 614 |
+
for line in lines:
|
| 615 |
+
outfile.write(line + "\n")
|
| 616 |
+
footer = """
|
| 617 |
+
var container = document.getElementById('mynetwork');
|
| 618 |
+
var data= {
|
| 619 |
+
nodes: nodes,
|
| 620 |
+
edges: edges,
|
| 621 |
+
};
|
| 622 |
+
var options = {
|
| 623 |
+
width: '100%',
|
| 624 |
+
height: '800px',
|
| 625 |
+
physics: {
|
| 626 |
+
enabled: false,
|
| 627 |
+
stabilization: {
|
| 628 |
+
updateInterval: 10,
|
| 629 |
+
},
|
| 630 |
+
},
|
| 631 |
+
edges: {
|
| 632 |
+
color: {
|
| 633 |
+
inherit: false
|
| 634 |
+
}
|
| 635 |
+
},
|
| 636 |
+
layout: {
|
| 637 |
+
hierarchical: {
|
| 638 |
+
direction: "UD",
|
| 639 |
+
sortMethod: "directed",
|
| 640 |
+
shakeTowards: "roots",
|
| 641 |
+
levelSeparation: 150, // Adjust as needed
|
| 642 |
+
nodeSpacing: 800, // Adjust as needed
|
| 643 |
+
treeSpacing: 200, // Adjust as needed
|
| 644 |
+
parentCentralization: true,
|
| 645 |
+
}
|
| 646 |
+
}
|
| 647 |
+
};
|
| 648 |
+
var network = new vis.Network(container, data, options);
|
| 649 |
+
|
| 650 |
+
network.on('beforeDrawing', function(ctx) {
|
| 651 |
+
nodes.forEach(function(node) {
|
| 652 |
+
if (node.isConsensus) {
|
| 653 |
+
// Set the level of spine nodes to the bottom
|
| 654 |
+
network.body.data.nodes.update({
|
| 655 |
+
id: node.id,
|
| 656 |
+
level: 0 // Set level to 0 for spine nodes
|
| 657 |
+
});
|
| 658 |
+
}
|
| 659 |
+
});
|
| 660 |
+
});
|
| 661 |
+
|
| 662 |
+
network.on("stabilizationProgress", function (params) {
|
| 663 |
+
document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
|
| 664 |
+
});
|
| 665 |
+
network.once("stabilizationIterationsDone", function () {
|
| 666 |
+
document.getElementById("loadingProgress").innerText = "100%";
|
| 667 |
+
setTimeout(function () {
|
| 668 |
+
document.getElementById("loadingProgress").style.display = "none";
|
| 669 |
+
}, 500);
|
| 670 |
+
});
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
</script>
|
| 674 |
+
|
| 675 |
+
</body>
|
| 676 |
+
</html>
|
| 677 |
+
"""
|
| 678 |
+
outfile.write(textwrap.dedent(footer))
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def multi_consensus_response(self, abstention_threshold: Optional[float] = None, filter: bool = True):
|
| 682 |
+
self.toposort()
|
| 683 |
+
nodesInReverse = self.nodeidlist[::-1]
|
| 684 |
+
maxnodeID = self.end_id
|
| 685 |
+
nextInPath = [-1] * maxnodeID
|
| 686 |
+
scores = np.zeros(len(self.nodeidlist))
|
| 687 |
+
|
| 688 |
+
id_to_index = {node_id: index for index, node_id in enumerate(self.nodeidlist)}
|
| 689 |
+
index_to_id = {index: node_id for index, node_id in enumerate(self.nodeidlist)}
|
| 690 |
+
|
| 691 |
+
for nodeID in nodesInReverse:
|
| 692 |
+
bestWeightScoreEdges = [(-1, -1, None)]
|
| 693 |
+
for neighbourID in self.nodedict[nodeID].outEdges:
|
| 694 |
+
# print(f"nodeID: {nodeID}, neighbourID: {neighbourID}")
|
| 695 |
+
e = self.nodedict[nodeID].outEdges[neighbourID]
|
| 696 |
+
weightScoreEdge = (e.weight, scores[id_to_index[neighbourID]], neighbourID)
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
if weightScoreEdge > bestWeightScoreEdges[0]:
|
| 700 |
+
bestWeightScoreEdges = [weightScoreEdge]
|
| 701 |
+
elif weightScoreEdge == bestWeightScoreEdges[0] and filter:
|
| 702 |
+
bestWeightScoreEdges.append(weightScoreEdge)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
scores[id_to_index[nodeID]] = sum(bestWeightScoreEdges[0][0:2])
|
| 706 |
+
if bestWeightScoreEdges[0][2] is not None:
|
| 707 |
+
nextInPath[id_to_index[nodeID]] = id_to_index[bestWeightScoreEdges[0][2]]
|
| 708 |
+
else:
|
| 709 |
+
nextInPath[id_to_index[nodeID]] = None
|
| 710 |
+
|
| 711 |
+
pos = np.argmax(scores)
|
| 712 |
+
path = []
|
| 713 |
+
text = []
|
| 714 |
+
labels = []
|
| 715 |
+
|
| 716 |
+
while pos is not None and pos > -1:
|
| 717 |
+
if abstention_threshold is not None and self.nodedict[index_to_id[pos]].variations:
|
| 718 |
+
if (
|
| 719 |
+
len(self.nodedict[index_to_id[pos]].labels) / self.num_sequences
|
| 720 |
+
>= abstention_threshold
|
| 721 |
+
):
|
| 722 |
+
path.append(index_to_id[pos])
|
| 723 |
+
labels.append(self.nodedict[index_to_id[pos]].labels)
|
| 724 |
+
text.append(self.nodedict[index_to_id[pos]].text)
|
| 725 |
+
else:
|
| 726 |
+
path.append(index_to_id[pos])
|
| 727 |
+
labels.append(self.nodedict[index_to_id[pos]].labels)
|
| 728 |
+
text.append(self.nodedict[index_to_id[pos]].text)
|
| 729 |
+
pos = nextInPath[pos]
|
| 730 |
+
|
| 731 |
+
# ignore END node
|
| 732 |
+
path = path[:-1]
|
| 733 |
+
# ignore END node
|
| 734 |
+
text = text[:-1]
|
| 735 |
+
# ignore START in text
|
| 736 |
+
text[0] = text[0].replace("START", "")
|
| 737 |
+
labels = labels[:-1]
|
| 738 |
+
|
| 739 |
+
return " ".join(text)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def consensus_response(
|
| 743 |
+
self, selection_threshold: Optional[float] = 0.5, api: str = "openai" , model: str = "gpt-4o-mini", task: str = "bio", **kwargs
|
| 744 |
+
) -> str:
|
| 745 |
+
self.toposort()
|
| 746 |
+
|
| 747 |
+
consensus_node_ids = self.consensus_node_ids
|
| 748 |
+
print(consensus_node_ids)
|
| 749 |
+
|
| 750 |
+
selected_node_ids = []
|
| 751 |
+
|
| 752 |
+
for node_id in consensus_node_ids:
|
| 753 |
+
if node_id == self.start_id or node_id == self.end_id:
|
| 754 |
+
continue
|
| 755 |
+
|
| 756 |
+
selected_node_ids.append(node_id)
|
| 757 |
+
|
| 758 |
+
for neighbor_id in self.nodedict[node_id].outEdges:
|
| 759 |
+
if neighbor_id in consensus_node_ids:
|
| 760 |
+
continue
|
| 761 |
+
|
| 762 |
+
if (
|
| 763 |
+
len(self.nodedict[neighbor_id].labels) / self.num_sequences
|
| 764 |
+
>= selection_threshold
|
| 765 |
+
):
|
| 766 |
+
selected_node_ids.append(neighbor_id)
|
| 767 |
+
|
| 768 |
+
text = " ".join([self.nodedict[node_id].text for node_id in selected_node_ids])
|
| 769 |
+
print(text)
|
| 770 |
+
cleaned_text = clean_up_text(text, task=task, api=api, model=model, **kwargs)
|
| 771 |
+
return cleaned_text
|
| 772 |
+
|
| 773 |
+
def save_to_pickle(self, filename):
|
| 774 |
+
with open(filename, "wb+") as f:
|
| 775 |
+
pickle.dump(self, f)
|
| 776 |
+
|
| 777 |
+
def refine_graph(
|
| 778 |
+
self,
|
| 779 |
+
verbose: bool = False,
|
| 780 |
+
save_intermediate_file: str = None,
|
| 781 |
+
final_merge: bool = True,
|
| 782 |
+
**kwargs,
|
| 783 |
+
):
|
| 784 |
+
self.merge_consensus_nodes(verbose=verbose)
|
| 785 |
+
|
| 786 |
+
if save_intermediate_file:
|
| 787 |
+
with open(save_intermediate_file, "w+") as f:
|
| 788 |
+
self.htmlOutput(f, annotate_consensus=False)
|
| 789 |
+
|
| 790 |
+
if not self.consensus_node_ids:
|
| 791 |
+
self.failed = True
|
| 792 |
+
return
|
| 793 |
+
|
| 794 |
+
else:
|
| 795 |
+
self.merge_divergent_paths(verbose=verbose, **kwargs)
|
| 796 |
+
|
| 797 |
+
if final_merge:
|
| 798 |
+
try:
|
| 799 |
+
self.merge_consensus_nodes(verbose=verbose)
|
| 800 |
+
except Exception as e:
|
| 801 |
+
print(e)
|
| 802 |
+
self.failed = True
|
src/text_poa_graph_utils.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# System prompts for the LLM similarity judge used by path_sim_llm: the judge
# must answer exactly "equivalent" or "not equivalent" for a pair of texts.
# (NOTE(review): internal indentation of the bullet lists was reconstructed
# from a formatting-mangled source - verify against the original file.)
TEXT_SIMILARITY_JUDGE_PROMPT = """
You are given two pieces of text. Your task is to determine whether they are semantically equivalent based solely on their factual content.

Here are the specific guidelines:
- Texts are equivalent if they convey the same core information or concept, regardless of wording or structure
- If one text has information that is a subset of the other text, then the texts are equivalent
- Focus ONLY on the essential claims, not on:
  * Stylistic differences or tone
  * Level of detail (if the core facts remain the same)
  * Connotative differences between words
  * Implied significance or emphasis
  * Presentation order (if all key information is present in both)
- Minor additions of non-contradictory information should not make texts non-equivalent
- For ambiguous cases, prioritize the central claim or purpose of the text

Examples of equivalent pairs:
- "The meeting starts at 3pm" and "The 3 o'clock meeting will begin on time"
- "Research indicates a 15% increase" and "Studies show a fifteen percent rise"
- "was influential in the field" and "had a significant impact on the community"

Examples of non-equivalent pairs:
- "The project might be completed by Friday" and "The project will be finished by Friday"
- "Most experts agree on the approach" and "All experts support the approach"

Strictly follow these guidelines and return ONLY:
- equivalent
- not equivalent
"""
# Domain-specific variant for mathematical solution segments.
MATH_SIMILARITY_JUDGE_PROMPT = """
You are given two pieces of text from mathematical solutions. Your task is to determine whether the two solution segments are mathematically equivalent in their content, while allowing for stylistic variations.

Here are some important guidelines:
- Solutions should be considered equivalent if:
  1. They communicate the same mathematical content/approach, even if word choice or phrasing differs
  2. They contain the same key mathematical ideas, even if expressed differently
  3. The same mathematical steps are described, even if using different words
  4. They present the same final answer, regardless of wording style or formatting

- Allow for these variations while still considering solutions equivalent:
  1. Stylistic differences ("we will" vs. "we'll" or "I'll")
  2. Different levels of formality in the explanation
  3. Minor rephrasing that preserves the core mathematical content
  4. Use of synonyms or alternative mathematical terminology for the same concept

- Solutions are NOT equivalent if:
  1. They use fundamentally different mathematical approaches
  2. They work with different formulas or equations
  3. They present different mathematical steps or operations
  4. They reach different conclusions or answers
  5. One contains substantial mathematical content that the other lacks

- When examining final answers, focus on mathematical equivalence rather than stylistic presentation
- For solution steps, maintain the core mathematical approach while allowing for rephrasing

Examples of solutions that SHOULD be considered equivalent:
- "We will systematically evaluate each possible grouping" and "We'll evaluate each grouping"
- "The answer is x = 5" and "Therefore, x equals 5"
- "Using the quadratic formula" and "Applying the quadratic formula"

Strictly follow the guidelines above.
Return your judgment in the following format. Do not include any other text:
- equivalent
- not equivalent
"""
|
| 71 |
+
|
| 72 |
+
def path_sim_llm(
    path1_text: str,
    path2_text: str,
    api: str = "openai",
    model: str = "gpt-4.1-mini",
    verbose: bool = False,
    domain: Optional[str] = "text",
    custom_similarity_judge_prompt: Optional[str] = None,
):
    """Ask an LLM judge whether two reasoning paths are equivalent.

    The two texts are appended to a domain-specific judge prompt and sent to a
    chat-completion endpoint at temperature 0; the reply is normalized and
    mapped to a binary similarity score.

    Args:
        path1_text: First path/solution text to compare.
        path2_text: Second path/solution text to compare.
        api: Backend to use — "openai" (OpenAI client) or "hf"
            (Hugging Face InferenceClient).
        model: Model name passed to the chat-completions call.
        verbose: If True, print the pair and the judgement (and any
            unparseable judgement).
        domain: "text" or "math" selects a built-in judge prompt. Pass a
            falsy value together with ``custom_similarity_judge_prompt`` to
            use a caller-supplied prompt instead.
        custom_similarity_judge_prompt: Custom judge prompt, used only when
            ``domain`` is falsy.

    Returns:
        Tuple ``(score, prompt_tokens, completion_tokens)`` where ``score``
        is 1 if the judge answered "equivalent" and 0 otherwise (including
        unparseable replies, which are conservatively treated as not
        equivalent).

    Raises:
        ValueError: If ``api`` is unknown, or ``domain`` is unknown and no
            custom prompt was provided.
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    else:
        raise ValueError(f"Invalid API: {api}")

    # Pick the judge prompt, then append the two texts in a fixed layout.
    if domain == "text":
        base_prompt = TEXT_SIMILARITY_JUDGE_PROMPT
    elif domain == "math":
        base_prompt = MATH_SIMILARITY_JUDGE_PROMPT
    elif not domain and custom_similarity_judge_prompt:
        base_prompt = custom_similarity_judge_prompt
    else:
        raise ValueError(f"Invalid domain: {domain} and no custom similarity judge prompt provided")
    similarity_judge_prompt = (
        f"{base_prompt}\n\nText 1: {path1_text}\nText 2: {path2_text}"
    )

    completion = client.chat.completions.create(
        model=model,
        temperature=0,  # deterministic judging
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": similarity_judge_prompt},
        ],
    )

    # Normalize the reply: keep letters/spaces only (drops "- " bullets and
    # punctuation), then lowercase so e.g. "Equivalent." still matches.
    # The original case-sensitive comparison silently scored capitalized
    # replies as 0.
    judgement = completion.choices[0].message.content.strip()
    judgement = "".join(c for c in judgement if c.isalpha() or c == " ")
    judgement = judgement.strip().lower()

    if verbose:
        print(f"{path1_text} \nand \n{path2_text} \nare {judgement}")

    if judgement == "equivalent":
        return 1, completion.usage.prompt_tokens, completion.usage.completion_tokens
    elif judgement == "not equivalent":
        return 0, completion.usage.prompt_tokens, completion.usage.completion_tokens
    else:
        # Anything else is off-format; treat as not equivalent.
        if verbose:
            print(f"Invalid judgement: {judgement}")
        return 0, completion.usage.prompt_tokens, completion.usage.completion_tokens
|
src/utils.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import InferenceClient
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def detect_abstain(text: str, api: str, model: str) -> str:
    """Ask an LLM whether a biography snippet abstains from answering.

    Args:
        text: A fragment of a biography of some entity.
        api: Backend to use — "openai" (OpenAI client) or "hf"
            (Hugging Face InferenceClient).
        model: Model name passed to the chat-completions call.

    Returns:
        The model's raw (stripped) reply — the prompt instructs it to answer
        "Abstain" or "Not abstain", but the output is not validated here, so
        callers should tolerate off-format replies.

    Raises:
        ValueError: If ``api`` is not "openai" or "hf".
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    else:
        raise ValueError(f"Invalid API: {api}")

    # Prompt text is part of the runtime behavior — do not reformat.
    detect_abstain_prompt = f"""
You are given a piece of text that is a part of a biography of an entity.
Text: {text}

If the text claims a lack of knowledge about the topic, return "Abstain".
Otherwise, return "Not abstain".
"""

    # NOTE(review): unlike path_sim_llm, no temperature is set here, so the
    # provider default applies — confirm whether deterministic judging is
    # intended.
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": detect_abstain_prompt},
        ],
    )

    return completion.choices[0].message.content.strip()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def calculate_factf1_at_k(
    supported_facts: List[str], unsupported_facts: List[str], k: int
) -> float:
    """Calculate the fact-F1 score at k.

    Precision is the fraction of all extracted facts that are supported;
    recall is the number of supported facts relative to the target ``k``,
    capped at 1. The score is their harmonic mean.

    Args:
        supported_facts: Facts judged supported.
        unsupported_facts: Facts judged unsupported.
        k: Recall target — ``k`` supported facts count as full recall.
            Must be positive (``k == 0`` would divide by zero).

    Returns:
        The F1 score in [0, 1]; 0.0 when there are no supported facts
        (precision is zero, so F1 is zero by convention).
    """
    if not supported_facts:
        # Fix: return a float to match the annotated return type
        # (original returned int 0 here).
        return 0.0

    total_facts = len(supported_facts) + len(unsupported_facts)
    precision = len(supported_facts) / total_facts
    recall = min(len(supported_facts) / k, 1)
    return 2 * precision * recall / (precision + recall)
|