Shahzaib98 commited on
Commit
d2ff6a7
·
verified ·
1 Parent(s): e6400bc

Upload 11 files

Browse files
src/__init__.py ADDED
File without changes
src/alignment.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from Jonathan Dursi
3
+ https://github.com/ljdursi/poapy
4
+ """
5
+
6
+ import numpy
7
+
8
+
9
class SeqGraphAlignment(object):
    """Smith-Waterman-style alignment of a string against a partial-order
    alignment (POA) graph.

    One dynamic-programming axis is the sequence; the other is the
    topologically sorted node list of the graph, so a cell can have several
    "previous row" candidates — one per graph predecessor. Supports local
    (default) and global alignment. Adapted from Jonathan Dursi's poapy.
    """

    # Default scoring parameters; overridable via the constructor.
    __matchscore = 1
    __mismatchscore = -2
    __gap = -1

    def __init__(
        self,
        sequence,
        graph,
        fastMethod=True,
        globalAlign=False,
        matchscore=__matchscore,
        mismatchscore=__mismatchscore,
        gapscore=__gap,
        *args,
        **kwargs,
    ):
        """Compute the alignment immediately on construction.

        Args:
            sequence: string to align.
            graph: POA graph exposing nodeiterator(), nodedict and nNodes.
            fastMethod: use the numpy-vectorized DP when True, else the
                simple per-cell loop.
            globalAlign: global alignment when True, local otherwise.
            matchscore: score added for a matching character.
            mismatchscore: score added for a mismatching character.
            gapscore: score added for an insertion or deletion.

        The resulting alignment is stored in `self.stringidxs` /
        `self.nodeidxs` (parallel lists; None marks a gap on that side).
        """
        self._mismatchscore = mismatchscore
        self._matchscore = matchscore
        self._gap = gapscore
        self.sequence = sequence
        self.graph = graph
        self.stringidxs = None
        self.nodeidxs = None
        self.globalAlign = globalAlign
        if fastMethod:
            matches = self.alignStringToGraphFast(*args, **kwargs)
        else:
            matches = self.alignStringToGraphSimple(*args, **kwargs)
        self.stringidxs, self.nodeidxs = matches

    def alignmentStrings(self):
        """Return the pair (sequence_string, graph_string), with '-' at gaps."""
        return "".join(
            self.sequence[i] if i is not None else "-" for i in self.stringidxs
        ), "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs)

    def matchscore(self, c1, c2):
        """Score a single character pair: match or mismatch."""
        if c1 == c2:
            return self._matchscore
        else:
            return self._mismatchscore

    def matchscoreVec(self, c, v):
        """Vectorized matchscore: compare character `c` against char array `v`."""
        return numpy.where(v == c, self._matchscore, self._mismatchscore)

    def alignStringToGraphSimple(self):
        """Align string to graph, following same approach as smith waterman
        example (one explicit loop per DP cell)."""
        if type(self.sequence) is not str:
            raise TypeError("Invalid Type")

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
            self.initializeDynamicProgrammingData()
        )

        # Dynamic Programming
        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            pbase = node.text

            for j, sbase in enumerate(self.sequence):
                # Gather every candidate move for this cell — an insertion,
                # plus a deletion and a match per graph predecessor — then
                # keep the best-scoring one.
                candidates = [(scores[i + 1, j] + self._gap, i + 1, j, "INS")]
                for predIndex in self.prevIndices(node, nodeIDtoIndex):
                    candidates += [
                        (scores[predIndex + 1, j + 1] + self._gap, predIndex + 1, j + 1, "DEL")
                    ]
                    candidates += [
                        (
                            scores[predIndex + 1, j] + self.matchscore(sbase, pbase),
                            predIndex + 1,
                            j,
                            "MATCH",
                        )
                    ]

                (
                    scores[i + 1, j + 1],
                    backGrphIdx[i + 1, j + 1],
                    backStrIdx[i + 1, j + 1],
                    movetype,
                ) = max(candidates)

                # Local alignment never goes negative; -1 backpointers mark
                # "restart the alignment here".
                if not self.globalAlign and scores[i + 1, j + 1] < 0:
                    scores[i + 1, j + 1] = 0.0
                    backGrphIdx[i + 1, j + 1] = -1
                    backStrIdx[i + 1, j + 1] = -1

        return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)

    def alignStringToGraphFast(self):
        """Align string to graph - using numpy to vectorize across the string
        at each iteration."""
        if type(self.sequence) is not str:
            raise TypeError("Invalid Type")

        l2 = len(self.sequence)
        seqvec = numpy.array(list(self.sequence))

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx = (
            self.initializeDynamicProgrammingData()
        )
        inserted = numpy.zeros((l2), dtype=bool)

        # having the inner loop as a function improves performance
        # can use Cython, etc on this for significant further improvements
        # can't vectorize this since there's a loop-carried dependency
        # along the string
        def insertions(i, l2, scores, inserted):
            inserted[:] = False
            for j in range(l2):
                insscore = scores[i + 1, j] + self._gap
                if insscore >= scores[i + 1, j + 1]:
                    scores[i + 1, j + 1] = insscore
                    inserted[j] = True

        # Dynamic Programming
        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            gbase = node.text
            predecessors = self.prevIndices(node, nodeIDtoIndex)

            # calculate all best deletions, matches in one go over all
            # predecessors.

            # First calculate for the first predecessor, over all string posns:
            deletescore = scores[predecessors[0] + 1, 1:] + self._gap
            bestdelete = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1

            matchpoints = self.matchscoreVec(gbase, seqvec)
            matchscore = scores[predecessors[0] + 1, 0:-1] + matchpoints
            bestmatch = numpy.zeros((l2), dtype=numpy.int32) + predecessors[0] + 1

            # then, the remaining
            for predecessor in predecessors[1:]:
                newdeletescore = scores[predecessor + 1, 1:] + self._gap
                bestdelete = numpy.where(newdeletescore > deletescore, predecessor + 1, bestdelete)
                deletescore = numpy.maximum(newdeletescore, deletescore)

                # NOTE(review): match points here are rescored against the
                # *predecessor's* base via nodeIdxToBase, not this node's —
                # confirm this is intentional (it differs from the first
                # predecessor above, which uses the current node's text).
                gbase = self.graph.nodeIdxToBase(predecessor)
                matchpoints = self.matchscoreVec(gbase, seqvec)
                newmatchscore = scores[predecessor + 1, 0:-1] + matchpoints
                bestmatch = numpy.where(newmatchscore > matchscore, predecessor + 1, bestmatch)
                matchscore = numpy.maximum(newmatchscore, matchscore)

            # choose best options available of match, delete
            deleted = deletescore >= matchscore
            backGrphIdx[i + 1, 1:] = numpy.where(deleted, bestdelete, bestmatch)
            backStrIdx[i + 1, 1:] = numpy.where(
                deleted, numpy.arange(1, l2 + 1), numpy.arange(0, l2)
            )
            scores[i + 1, 1:] = numpy.where(deleted, deletescore, matchscore)

            # insertions: updated in place, don't depend on predecessors
            insertions(i, l2, scores, inserted)
            backGrphIdx[i + 1, 1:] = numpy.where(inserted, i + 1, backGrphIdx[i + 1, 1:])
            backStrIdx[i + 1, 1:] = numpy.where(inserted, numpy.arange(l2), backStrIdx[i + 1, 1:])

            # if we're doing local alignment, don't let bad global alignment
            # drag us negative
            if not self.globalAlign:
                backGrphIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backGrphIdx[i + 1, :], -1)
                backStrIdx[i + 1, :] = numpy.where(scores[i + 1, :] > 0, backStrIdx[i + 1, :], -1)
                scores[i + 1, :] = numpy.maximum(scores[i + 1, :], 0)

        return self.backtrack(scores, backStrIdx, backGrphIdx, nodeIndexToID)

    def prevIndices(self, node, nodeIDtoIndex):
        """Return a list of the previous dynamic programming table indices
        corresponding to predecessors of the current node."""
        prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
        # if no predecessors, point to just before the graph
        if not prev:
            prev = [-1]
        return prev

    def initializeDynamicProgrammingData(self):
        """Initialize the dynamic programming tables:
        - set up scores array
        - set up backtracking array
        - create index to Node ID table and vice versa"""
        l1 = self.graph.nNodes
        l2 = len(self.sequence)

        nodeIDtoIndex = {}
        nodeIndexToID = {-1: None}
        # generate a dict of (nodeID) -> (index into nodelist (and thus matrix))
        ni = self.graph.nodeiterator()
        for index, node in enumerate(ni()):
            nodeIDtoIndex[node.ID] = index
            nodeIndexToID[index] = node.ID

        # Dynamic Programming data structures; scores matrix and backtracking
        # matrix
        scores = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)

        # initialize insertion score
        # if global align, penalty for starting at head != 0
        if self.globalAlign:
            scores[0, :] = numpy.arange(l2 + 1) * self._gap

            ni = self.graph.nodeiterator()
            for index, node in enumerate(ni()):
                prevIdxs = self.prevIndices(node, nodeIDtoIndex)
                best = scores[prevIdxs[0] + 1, 0]
                for prevIdx in prevIdxs:
                    best = max(best, scores[prevIdx + 1, 0])
                scores[index + 1, 0] = best + self._gap

        # backtracking matrices
        backStrIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)
        backGrphIdx = numpy.zeros((l1 + 1, l2 + 1), dtype=numpy.int32)

        return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx

    def backtrack(self, scores, backStrIdx, backGrphIdx, nodeIndexToID):
        """Backtrack through the scores and backtrack arrays.
        Return a list of sequence indices and node IDs (not indices, which
        depend on ordering)."""
        besti, bestj = scores.shape
        besti -= 1
        bestj -= 1
        if not self.globalAlign:
            # Local alignment: start from the highest-scoring cell.
            besti, bestj = numpy.argwhere(scores == numpy.amax(scores))[-1]
        else:
            ni = self.graph.nodeiterator()
            # still have to find best final index to start from:
            # best score in the last column among terminal (out-degree 0) nodes.
            # (BUGFIX: removed stray debug `print(terminalIndices)`.)
            terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
            besti = terminalIndices[0] + 1
            bestscore = scores[besti, bestj]
            for i in terminalIndices[1:]:
                score = scores[i + 1, bestj]
                if score > bestscore:
                    bestscore, besti = score, i + 1

        matches = []
        strindexes = []
        while (self.globalAlign or scores[besti, bestj] > 0) and (besti != 0 or bestj != 0):
            nexti, nextj = backGrphIdx[besti, bestj], backStrIdx[besti, bestj]
            curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]

            # An unchanged index on one side means a gap on that side (None).
            strindexes.insert(0, curstridx if nextj != bestj else None)
            matches.insert(0, curnodeidx if nexti != besti else None)

            besti, bestj = nexti, nextj

        return strindexes, matches
src/generation_methods.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from src.generation_utils import (
4
+ extract_alternative_paths,
5
+ extract_context,
6
+ extract_equivalent_classes,
7
+ self_complete,
8
+ verify_correctness_pairwise,
9
+ )
10
+ from src.global_edit_utils import clean_up_text
11
+ from src.text_poa_graph import TextPOAGraph
12
+
13
+ """
14
+ Decodes from a TextPOAGraph object to a string by sequentially selecting nodes based on the selection threshold.
15
+ Only the primary variation of selected variable nodes are selected.
16
+ Text is edited using the global_edit_function (e.g. to clean up text by removing incoherencies, disfluencies, and redundancies).
17
+
18
+ Args:
19
+ text_poa_graph: The TextPOAGraph object to decode.
20
+ selection_threshold: The threshold for selecting nodes.
21
+ model: The model to use for decoding.
22
+
23
+ Returns:
24
+ A string of the decoded text.
25
+ """
26
+
27
+
28
def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: Optional[float] = 0.5,
    task: str = "bio",
    verbose: bool = False,
    **kwargs,
) -> str:
    """Decode a consensus string from a TextPOAGraph.

    Walks the graph's consensus path, additionally pulling in off-consensus
    neighbor nodes supported by at least `selection_threshold` of the input
    sequences. For variable nodes the longest of {node text, variations} is
    kept. The concatenated text is then cleaned with `clean_up_text`.

    (This docstring replaces the module-level string that previously sat
    above this function and was never attached to it.)

    Args:
        text_poa_graph: graph to decode; "Abstain" is returned if it failed
            to build.
        selection_threshold: minimum fraction of sequences that must pass
            through an off-consensus neighbor for it to be included. Must
            not be None despite the Optional annotation — the comparison
            below does not handle None.
        task: task name forwarded to `clean_up_text` (e.g. "bio").
        verbose: when True, return `(raw_text, edited_text)` instead of
            just the edited text.
        **kwargs: forwarded to `clean_up_text`.

    Returns:
        The cleaned decoded text, or "Abstain", or a `(raw, cleaned)` tuple
        when `verbose` is True.
    """
    if text_poa_graph.failed:
        return "Abstain"

    text_poa_graph.toposort()

    consensus_node_ids = text_poa_graph.consensus_node_ids

    selected_node_ids = []

    for node_id in consensus_node_ids:
        # Skip the artificial start/end sentinel nodes.
        if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
            continue

        selected_node_ids.append(node_id)

        # Also keep well-supported neighbors that fell off the consensus path.
        for neighbor_id in text_poa_graph.nodedict[node_id].outEdges:
            if neighbor_id in consensus_node_ids:
                continue

            if (
                len(text_poa_graph.nodedict[neighbor_id].labels) / text_poa_graph.num_sequences
                >= selection_threshold
            ):
                selected_node_ids.append(neighbor_id)

    texts = []
    for node_id in selected_node_ids:
        node = text_poa_graph.nodedict[node_id]
        if not node.variations:
            texts.append(node.text)
        else:
            # Among the node's own text and all its variations,
            # select the variation that is longest.
            all_texts = list(node.variations.values())
            all_texts.append(node.text)
            texts.append(max(all_texts, key=len))
    text = " ".join(texts)
    edited_text = clean_up_text(text=text, task=task, api="openai", **kwargs)
    if verbose:
        return text, edited_text
    else:
        return edited_text
75
+
76
+
77
def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Decode by pruning uncertain branches via pairwise LLM self-verification.

    Pipeline:
      1. Mark consensus nodes whose branching factor (out-edges per input
         sequence) exceeds `uncertainty_threshold` as high-uncertainty.
      2. Render each sequence's path to text, wrapping the uncertain region
         after each such node in *START_SEPARATOR*/*END_SEPARATOR* markers.
      3. At each uncertain node, group sequences into branch-equivalence
         classes and pairwise-verify class representatives, pruning classes
         scored incorrect (optionally after one grace-period miss).
      4. Build a final prompt (a "patch" prompt when everything was pruned)
         and ask the model for a completed solution.

    Returns:
        Tuple of (model completion text, dict of per-label masked candidate
        texts).
    """
    # --- Step 1: find high-branching ("uncertain") consensus nodes. ---
    high_uncertainty_nodes = []
    for node_id in text_poa_graph.consensus_node_ids:
        if node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id:
            continue

        outgoing_edges = text_poa_graph.nodedict[node_id].outEdges
        branching_factor = len(outgoing_edges) / text_poa_graph.num_sequences

        if branching_factor > uncertainty_threshold:
            high_uncertainty_nodes.append(node_id)

    # --- Step 2: render each sequence path with uncertainty markers. ---
    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = {}
    # NOTE(review): `uncertain_region` is initialized outside the label loop,
    # so a region left open at the end of one path carries into the next
    # label's rendering — confirm this is intended.
    uncertain_region = False
    for label in selected_labels:
        text = ""
        for node_id in text_poa_graph._seq_paths[label]:
            if uncertain_region:
                text += f" *START_SEPARATOR*_{node_id} "
            if node_id in high_uncertainty_nodes:
                uncertain_region = True

            # Prefer the label-specific variation text when the node has any.
            if len(text_poa_graph.nodedict[node_id].variations) > 0:
                text += text_poa_graph.nodedict[node_id].variations[label]
                text += " "
            else:
                text += text_poa_graph.nodedict[node_id].text
                text += " "

            if uncertain_region and node_id not in high_uncertainty_nodes:
                text += f" *END_SEPARATOR*_{node_id} "
                uncertain_region = False
        masked_candidates[label] = text

    patch_start_node = None
    # NOTE(review): `uncertain_ids` is accumulated but never read afterwards.
    uncertain_ids = []

    # give a grace period for the first incorrect step
    prev_step = {label: None for label in selected_labels}

    # --- Step 3: pairwise verification and pruning at each uncertain node. ---
    for node_id in high_uncertainty_nodes:
        uncertain_ids.append(node_id)
        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
        new_labels = selected_labels.copy()

        # Only do self-verification for labels from different semantically
        # equivalent branches.
        if len(equivalent_classes) <= 1:
            continue
        i = 0
        while i < len(equivalent_classes):
            if i + 1 < len(equivalent_classes):
                # Compare one representative from each of two adjacent classes.
                label_a = equivalent_classes[i][0]
                label_b = equivalent_classes[i + 1][0]
                full_a = context_before[label_a] + alternative_paths[label_a]
                full_b = context_before[label_b] + alternative_paths[label_b]

                score = verify_correctness_pairwise(
                    full_text_1=full_a,
                    full_text_2=full_b,
                    verification_model=verification_model,
                    problem=problem,
                    api=verification_api,
                )
                if float(score[0]) < 1.0:
                    print(f"Label {label_a} is incorrect at node {node_id}")
                    # Re-mark this region as a verified likely error.
                    masked_candidates[label_a] = (
                        masked_candidates[label_a]
                        .replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
                        .replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
                    )
                    if not prev_step[label_a]:
                        prev_step[label_a] = True
                    # NOTE(review): `prev_step[label_a]` was just set True
                    # above, so with grace_period=True pruning happens on the
                    # SAME failing step (reads as `(prev and grace) or not
                    # grace`) — confirm the grace period behaves as intended.
                    if prev_step[label_a] and grace_period or not grace_period:
                        # Prune the whole equivalence class of label_a.
                        for label_i in equivalent_classes[i]:
                            new_labels.remove(label_i)
                            print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
                if float(score[0]) == 1.0:
                    prev_step[label_a] = False
                if float(score[1]) < 1.0:
                    print(f"Label {label_b} is incorrect at node {node_id}")
                    masked_candidates[label_b] = (
                        masked_candidates[label_b]
                        .replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
                        .replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
                    )
                    if not prev_step[label_b]:
                        prev_step[label_b] = True
                    if prev_step[label_b] and grace_period or not grace_period:
                        for label_i in equivalent_classes[i + 1]:
                            new_labels.remove(label_i)
                            print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
                if float(score[1]) == 1.0:
                    prev_step[label_b] = False
                i += 2
            else:
                # Odd class left without a partner: stop comparing here.
                break

        if len(new_labels) == 0:
            # Everything was pruned: remember where, and fall back to patching.
            patch_start_node = node_id
            break

        selected_labels = new_labels.copy()

    # --- Step 4: build the final prompt from the surviving candidates. ---
    # These are the pruned approaches with masking
    print(masked_candidates)
    masked_approaches = "\n".join(
        [
            f"Approach {label}: {masked_candidates[label].replace('START_SEPARATOR', 'START_UNCERTAIN_REGION').replace('END_SEPARATOR', 'END_UNCERTAIN_REGION')}"
            for label in selected_labels
        ]
    )
    # These are all approaches with masking
    all_approaches = "\n".join(
        [f"Approach {label}: {masked_candidates[label]}" for label in masked_candidates.keys()]
    )

    default_prompt = f"""
    Solve the following math problem with mathematical precision and clarity.

    Problem: {problem}

    Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*).
    These sections may contain conceptual or computational errors.

    There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
    A verification step indicated that these steps are highly likely to contain errors.

    Potential Approaches:
    {masked_approaches}

    Your task:
    1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
    If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
    2. Using the sections with special markers, identify potential errors.
    3. Develop a rigorous, step-by-step solution based on sound mathematical principles
    4. For uncertain regions:
    - Verify each step using algebraic or numerical validation
    - If correct, incorporate these steps with appropriate justification
    - If incorrect, provide clear corrections with mathematical reasoning for your changes
    5. Follow a comparative approach, using the differences between approaches to identify potential errors.
    6. Do not blindly follow the approaches, but rather use them to identify potential errors.

    Guidelines for your solution:
    - Begin with a strategic overview of your chosen approach
    - Present each mathematical step with clear notation and justification
    - Pay special attention to areas that were previously marked uncertain

    Conclude your solution with:
    Therefore, the final answer is: $\\boxed{{answer}}$.

    Solution:
    """

    patch_prompt = f"""
    Solve the following mathematical problem with precision and clarity.

    Problem: {problem}

    You have been provided with several partial solution approaches that attempted to solve this problem.
    None of these approaches are correct, but may contain valuable insights.
    Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
    A verification step indicated that these steps are likely to contain errors.

    INSTRUCTIONS:
    1. Synthesize a correct solution using insights from the previous approaches
    2. Pay special attention to fixing the problematic areas marked by separators
    3. Develop your solution step-by-step, showing clear mathematical reasoning
    4. Focus especially on mathematical correctness in areas where previous solutions diverged
    5. Present your work in a logical, sequential manner suitable for an advanced reader

    GUIDELINES FOR MATHEMATICAL RIGOR:
    1. MAINTAIN MATHEMATICAL RIGOR
    - Verify that all mathematical operations follow from established principles and definitions
    - Ensure dimensional consistency throughout calculations
    - Check that algebraic manipulations preserve equality and do not introduce errors

    2. CONSIDER ALTERNATIVE PERSPECTIVES
    - Even when approaches reach the same conclusion, examine their reasoning independently
    - Look for more elegant or insightful connections that may be missed across all approaches
    - Consider whether fundamental mathematical principles suggest a different path

    3. CRITICAL VALIDATION
    - Test conclusions using known mathematical properties and relationships
    - When possible, verify results using alternative methods
    - Be especially cautious when all approaches agree on a result but use similar reasoning

    4. USE PRECISION IN CORRECTIONS
    - When correcting uncertain regions, specify exactly what was incorrect and why
    - Provide clear mathematical justification for any changes
    - Ensure corrections align with standard mathematical principles and notations

    Previous Approaches (for reference only):
    {all_approaches}

    Your Solution:
    [Begin with a clear statement of your approach]
    [Provide detailed mathematical steps]
    [Ensure correct handling of complex mathematical operations]
    [Verify your work at key points, especially in previously problematic areas]

    Always conclude with:
    Therefore, the final answer is: $\\boxed{{answer}}$
    """

    # Fall back to the patch prompt when every class was pruned or there was
    # only ever a single candidate.
    if patch_start_node is not None or len(masked_candidates.keys()) == 1:
        print("None correct, patching")
        prompt = patch_prompt
    else:
        prompt = default_prompt

    return self_complete(
        verification_prompt=prompt, verification_model=verification_model, api=verification_api
    ), masked_candidates
src/generation_utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from huggingface_hub import InferenceClient
4
+ from openai import OpenAI
5
+ from together import Together
6
+
7
+ from src.text_poa_graph import TextPOAGraph
8
+
9
+
10
def extract_context(text_poa_graph, node_id):
    """Return, per sequence label, the rendered path prefix ending at `node_id`.

    For each label, the node's label-specific variation is preferred over its
    base text. Raises ValueError if `node_id` is not on a sequence's path.
    """
    nodes = text_poa_graph.nodedict

    def render(nid, label):
        # Prefer the label's variation of this node; fall back to its text.
        node = nodes[nid]
        return node.variations.get(label, node.text)

    result = {}
    for label, path in text_poa_graph._seq_paths.items():
        cutoff = path.index(node_id) + 1
        pieces = [render(nid, label) for nid in path[:cutoff]]
        result[label] = " ".join(pieces)
    return result
21
+
22
+
23
def extract_alternative_paths(text_poa_graph: TextPOAGraph, node_id):
    """Extract all alternative paths from this uncertainty point to the next consensus node.

    Args:
        text_poa_graph: graph holding per-label sequence paths.
        node_id: node at which the paths diverge; must lie on every path.

    Returns:
        Dict mapping each sequence label to the rendered text of the path
        segment from just after `node_id` up to and including the next
        consensus node (empty string when no later consensus node exists on
        that path).
    """
    alternative_paths = {}
    for label, path in text_poa_graph._seq_paths.items():
        idx = path.index(node_id)
        next_cn = None
        for i in range(idx + 1, len(path)):
            if path[i] in text_poa_graph.consensus_node_ids:
                next_cn = path[i]
                break

        # BUGFIX: explicit None check — a consensus node with id 0 is falsy
        # and was previously treated as "no consensus node found".
        if next_cn is not None:
            next_cn_idx = path.index(next_cn)
            alternative_segment = path[idx + 1 : next_cn_idx + 1]
        else:
            alternative_segment = []

        alternative_paths[label] = " ".join(
            text_poa_graph.nodedict[nid].variations.get(label, text_poa_graph.nodedict[nid].text)
            for nid in alternative_segment
        )
    return alternative_paths
45
+
46
+
47
def is_same_branch(text_poa_graph: TextPOAGraph, node_id, lable_1, label_2):
    """Check whether two sequences take the same branch immediately after `node_id`.

    (The misspelled parameter name `lable_1` is kept for interface
    compatibility with existing keyword callers.)
    """

    def successor(label):
        # Node that follows `node_id` on this label's path.
        path = text_poa_graph._seq_paths[label]
        return path[path.index(node_id) + 1]

    return successor(lable_1) == successor(label_2)
54
+
55
+
56
def extract_equivalent_classes(text_poa_graph: TextPOAGraph, node_id, selected_labels):
    """Group labels into classes whose paths take the same branch after `node_id`.

    Each class is represented by its first member: a label joins the first
    existing class whose representative shares its successor node, otherwise
    it founds a new class. Returns a list of label lists ([] for no labels).
    """
    if not selected_labels:
        return []

    classes = []
    for label in selected_labels:
        for group in classes:
            if is_same_branch(text_poa_graph, node_id, group[0], label):
                group.append(label)
                break
        else:
            # No existing class matched: this label starts a new one.
            classes.append([label])
    return classes
72
+
73
+
74
def verify_correctness_pairwise(
    full_text_1: str, full_text_2: str, verification_model: str, problem: str, api: str = "openai"
):
    """Pairwise verification of two partial solution paths.

    Asks an LLM judge to compare two partial solutions to `problem` and score
    each one as valid (1) or invalid (0).

    Args:
        full_text_1: first partial solution (context + alternative segment).
        full_text_2: second partial solution.
        verification_model: model name understood by the chosen API.
        problem: the problem statement being solved.
        api: one of "openai", "hf" (Hugging Face InferenceClient) or "together".

    Returns:
        A (score_1, score_2) pair. When the model's reply contains a tuple
        like "(1, 0)", the LAST such tuple is returned as a pair of strings;
        otherwise the integer fallback (0, 0). Callers apply float() to each
        element, which handles both representations.

    Raises:
        ValueError: if `api` is not a supported backend.
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    elif api == "together":
        client = Together()
    else:
        raise ValueError(f"Invalid API: {api}")

    prompt = f"""
    You will be given a problem and 2 partial solutions.
    Your task is to use comparison as an EFFICIENCY TOOL to quickly identify potential errors.
    You will be given guidelines to follow, and you will be penalized if you do not follow them.

    Problem: {problem}

    Partial Solution 1: {full_text_1}
    Partial Solution 2: {full_text_2}

    CRITICAL GUIDELINES:
    - DO NOT penalize a solution for being incomplete or having missing steps
    - DO NOT make a comparison of which solution is better
    - DO NOT consider steps incorrect just because they differ between solutions
    - DO NOT prematurely evaluate based on final answers or future steps
    - DO NOT expect both solutions to be at the same stage of completion
    - DO NOT consider a step incorrect just because it lacks sufficient detail or justification

    KEY EFFICIENCY PRINCIPLE:
    - Use agreement between solutions as evidence of correctness
    - Use disagreement as a signal to investigate more deeply
    - Only label a step as an error if it contains a specific mathematical mistake
    - Incompleteness is not a mathematical error.

    Here are the instructions for how to complete your task:

    EFFICIENT VERIFICATION APPROACH:

    1. QUICK COMPARISON (Use this to focus your attention):
    - Immediately identify where the solutions differ in approach or results
    - Use these differences as "error hotspots" to prioritize your verification
    - When solutions agree, you can generally assume that part is correct
    - When solutions disagree, investigate those specific points deeply

    2. TARGETED VERIFICATION (Only where needed):
    - Most important: Do not consider any incomplete steps as errors
    - Focus your mathematical verification on the "hotspots" identified above
    - Check mathematical validity only at points of difference or uncertainty
    - Avoid line-by-line checking of steps where solutions agree
    - For each potential error spot, verify if the mathematical reasoning is valid
    - If an intermediate step is later corrected, do not penalize the solution for having the incorrect intermediate step

    After your targeted verification, propose a score tuple (score_1, score_2):
    - Score (1,1) if both partial solutions are valid
    - Score (1,0) if only the first solution is valid
    - Score (0,1) if only the second solution is valid
    - Score (0,0) if both solutions are invalid

    In case you score a solution as 0, you must give an explanation for each check below:
    3. FINAL CHECKS:
    - If you score a solution as 0, you MUST identify the specific mathematical error.
    - You must also double check the problem statement. Reconsider your score and determine if you have misinterpreted the problem statement.
    - You must also check whether you have penalized a solution for being incomplete or having missing steps.

    Before outputting your final score, you must answer these questions:
    STOP! Did you give a score of 0 to a solution that was incomplete?
    STOP! Did you penalize a solution for being incomplete or having missing steps?
    STOP! Did you make a comparison of which solution is better?
    STOP! Did you consider steps incorrect just because they differ between solutions?
    STOP! Did you prematurely evaluate based on final answers?
    STOP! Did you consider a step incorrect just because it lacks sufficient detail or justification?

    Now give your final score:
    Final score:
    """
    # temperature=0.0 keeps the judge as deterministic as the backend allows.
    completion = client.chat.completions.create(
        model=verification_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
    )
    response = completion.choices[0].message.content.strip()
    # NOTE(review): print-based tracing, consistent with the rest of this
    # module; consider switching to `logging`.
    print(full_text_1)
    print(full_text_2)
    print(f"Correctness score: {response} \n")
    # Take the LAST "(a, b)" tuple in the reply as the final verdict.
    score_match = re.findall(r"\(\s*([01](?:\.0)?)\s*,\s*([01](?:\.0)?)\s*\)", response)
    score = score_match[-1] if score_match else (0, 0)
    return score
167
+
168
+
169
+ def self_complete(verification_prompt: str, verification_model: str, api: str = "openai"):
170
+ print(verification_prompt)
171
+ """Completetion method"""
172
+ if api == "openai":
173
+ client = OpenAI()
174
+ elif api == "hf":
175
+ client = InferenceClient()
176
+ elif api == "together":
177
+ client = Together()
178
+ else:
179
+ raise ValueError(f"Invalid API: {api}")
180
+
181
+ completion = client.chat.completions.create(
182
+ model=verification_model,
183
+ messages=[
184
+ {"role": "system", "content": "You are a helpful assistant."},
185
+ {"role": "user", "content": verification_prompt},
186
+ ],
187
+ temperature=0.0,
188
+ )
189
+ response = completion.choices[0].message.content.strip()
190
+ return response
src/global_edit_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ from openai import OpenAI
3
+
4
+ bio_prompt = """
5
+ You are given a piece of text that is a part of a biography of an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
6
+ Then, remove any redundant information.
7
+ Text: {text}
8
+
9
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
10
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
11
+ Only return the cleaned up text. Do not include any other text:
12
+ """
13
+
14
+ fp_prompt = """
15
+ You are given a piece of text that is a part of a false presupposition task which includes outputting a list of items.
16
+ This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
17
+ Then, remove any redundant information.
18
+ Text: {text}
19
+
20
+ The resulting list of items should be separated by semicolons with no other text.
21
+ If this list it not possible to generate, return "Abstain".
22
+ """
23
+
24
+ hist_prompt = """
25
+ You are given a piece of text that is a part of a historical event task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
26
+ Then, remove any redundant information.
27
+ Text: {text}
28
+
29
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
30
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
31
+ Only return the cleaned up text. Do not include any other text:
32
+ """
33
+
34
+ refs_prompt = """
35
+ You are given a piece of text that is a part of a reference task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
36
+ Then, remove any redundant information.
37
+ Text: {text}
38
+
39
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
40
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
41
+ Only return the cleaned up text. Do not include any other text:
42
+ """
43
+
44
+ gpqa_prompt = """
45
+ You are given a piece of text that is a part of a graduate level question answering task. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
46
+ Then, remove any redundant information.
47
+ Text: {text}
48
+ Only return the cleaned up text. Do not include any other text:
49
+ """
50
+
51
+ popqa_prompt = """
52
+ You are given a piece of text that is a part of a paragraph which details facts related to an entity. This text may contain some minor errors that make it incoherent as well as potentially redundant information. Your task is to fix the errors and make the text coherent.
53
+ Then, remove any redundant information.
54
+ Text: {text}
55
+
56
+ If this is not possible because the text is just a fragment of a sentence, return "Abstain".
57
+ If the text already claims a lack of knowledge about the topic, return "Abstain".
58
+ Only return the cleaned up text. Do not include any other text:
59
+ """
60
+ task_to_prompt = {
61
+ "bio": bio_prompt,
62
+ "fp": fp_prompt,
63
+ "hist": hist_prompt,
64
+ "refs": refs_prompt,
65
+ "gpqa": gpqa_prompt,
66
+ "popqa": popqa_prompt
67
+ }
68
+
69
def clean_up_text(text: str, api: str, task: str, model: str = "gpt-4.1-mini", **kwargs):
    """Clean up disfluencies in the draft response in consensus decoding.

    Args:
        text: The text to clean up.
        api: The API to use for cleaning up the text: "openai" or "hf".
        task: The task key into ``task_to_prompt``: biography ("bio"), false
            presupposition ("fp"), historical event ("hist"), reference
            ("refs"), graduate question answering ("gpqa"), or paragraph
            question answering ("popqa").
        model: The OpenAI model to use (ignored for the "hf" backend).
        **kwargs: For ``api == "hf"``, must provide ``tokenizer`` and
            ``hf_model`` (a loaded transformers model).

    Returns:
        A string of the cleaned up text.

    Raises:
        ValueError: If *api* is unsupported, or the "hf" kwargs are missing.
    """
    if api == "hf":
        tokenizer = kwargs.get("tokenizer")
        hf_model = kwargs.get("hf_model")  # renamed locally: avoid shadowing `model`

        if tokenizer is None or hf_model is None:
            raise ValueError("For 'hf', both 'tokenizer' and 'model' must be provided.")

        clean_up_prompt = task_to_prompt[task].format(text=text)

        messages = [{"role": "user", "content": clean_up_prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(hf_model.device)

        # Greedy decode; stop at EOS.
        outputs = hf_model.generate(
            input_ids,
            max_new_tokens=500,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=[tokenizer.eos_token_id],
        )

        # Decode only the newly generated tokens (strip the prompt prefix).
        return tokenizer.decode(
            outputs[0][input_ids.shape[-1]:],
            skip_special_tokens=True
        ).strip()

    if api == "openai":
        client = OpenAI()
    else:
        raise ValueError(f"Invalid API: {api}")

    clean_up_prompt = task_to_prompt[task].format(text=text)

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": clean_up_prompt},
        ],
    )

    return completion.choices[0].message.content.strip()
src/new_alignment.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+
3
+
4
class ScoreParam:
    """Bundle of affine-gap alignment scoring parameters."""

    def __init__(self, match, mismatch, gap_open, gap_extend):
        # Reward for identical symbols / penalty for differing ones.
        self.match = match
        self.mismatch = mismatch
        # Affine gap model: opening a gap costs gap_open + gap_extend,
        # each further gapped position costs gap_extend.
        self.gap_open = gap_open
        self.gap_extend = gap_extend

    def __str__(self):
        fields = (
            f"Match: {self.match}",
            f"Mismatch: {self.mismatch}",
            f"Gap Open: {self.gap_open}",
            f"Gap Extend: {self.gap_extend}",
        )
        return ", ".join(fields)
13
+
14
+
15
class SeqGraphAlignment(object):
    """Affine-gap (Gotoh-style) alignment of a sequence against a POA graph.

    Three DP layers are stacked in one array ``scores[m, i, j]``:
      m == 0 (M): sequence position j aligned to graph node i,
      m == 1 (X): gap in the graph (sequence symbol consumed alone),
      m == 2 (Y): gap in the sequence (graph node consumed alone).
    """

    # Default affine-gap scoring, shared (read-only) across instances.
    __default_score = ScoreParam(1, -3, -2, -1)

    def __init__(
        self,
        sequence,
        graph,
        fastMethod=True,
        globalAlign=False,
        score_params=__default_score,
        *args,
        **kwargs,
    ):
        self.score = score_params
        self.sequence = sequence
        self.graph = graph
        self.stringidxs = None
        self.nodeidxs = None
        self.globalAlign = globalAlign
        if fastMethod:
            matches = self.alignStringToGraphFast(*args, **kwargs)
        else:
            matches = self.alignStringToGraphSimple(*args, **kwargs)
        self.stringidxs, self.nodeidxs = matches

    def alignmentStrings(self):
        """Return (sequence string, graph string) with '-' at gap positions."""
        return (
            "".join(self.sequence[i] if i is not None else "-" for i in self.stringidxs),
            "".join(self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs),
        )

    def matchscore(self, c1, c2):
        """Score for aligning two symbols: match reward or mismatch penalty."""
        if c1 == c2:
            return self.score.match
        else:
            return self.score.mismatch

    def matchscoreVec(self, c, v):
        """Vectorised matchscore of symbol *c* against numpy array *v*."""
        return numpy.where(v == c, self.score.match, self.score.mismatch)

    def prevIndices(self, node, nodeIDtoIndex):
        """Indices of a node's predecessors; [-1] (virtual start) if none."""
        prev = [nodeIDtoIndex[predID] for predID in list(node.inEdges.keys())]
        if not prev:
            prev = [-1]
        return prev

    def initializeDynamicProgrammingData(self):
        """Allocate and boundary-initialise the DP score/backtrack arrays.

        Returns:
            (nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx,
            backMtxIdx) where scores/back* are int32 arrays of shape
            (3, nNodes + 1, len(sequence) + 1).
        """
        l1 = self.graph.nNodes
        l2 = len(self.sequence)

        nodeIDtoIndex = {}
        nodeIndexToID = {-1: None}
        ni = self.graph.nodeiterator()
        for index, node in enumerate(ni()):
            nodeIDtoIndex[node.ID] = index
            nodeIndexToID[index] = node.ID

        scores = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)

        NEG_INF = -1000000000  # stands in for -infinity in int32 arithmetic

        if self.globalAlign:
            # M[0, j]: impossible (nothing to match against) except the origin.
            scores[0, 0, :] = NEG_INF
            scores[0, 0, 0] = 0
            # X[0, j]: j sequence symbols consumed by one opened gap.
            scores[1, 0, :] = [
                self.score.gap_open + i * self.score.gap_extend for i in range(l2 + 1)
            ]
            scores[1, 0, 0] = NEG_INF
            # Y[0, j]: impossible along the top edge.
            scores[2, 0, :] = NEG_INF

            ni = self.graph.nodeiterator()
            # After the topological sort, every predecessor has a smaller
            # index than the current node, so column 0 can be filled forward.
            for index, node in enumerate(ni()):
                scores[0, index + 1, 0] = NEG_INF
                scores[1, index + 1, 0] = NEG_INF
                prevIdxs = self.prevIndices(node, nodeIDtoIndex)
                best = scores[2, prevIdxs[0] + 1, 0]
                for prevIdx in prevIdxs:
                    best = max(best, scores[2, prevIdx + 1, 0])
                # If we have no predecessors, this node opens the leading gap.
                if prevIdxs == [-1]:
                    scores[2, index + 1, 0] = self.score.gap_open + self.score.gap_extend
                else:
                    scores[2, index + 1, 0] = best + self.score.gap_extend

        # 3D backtrack pointers: (matrix, graph index, string index) per cell.
        backStrIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
        backGrphIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)
        backMtxIdx = numpy.zeros((3, l1 + 1, l2 + 1), dtype=numpy.int32)

        return nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx

    def backtrack(self, scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID):
        """Follow the backtrack pointers from the best end cell.

        Returns:
            (strindexes, matches): aligned sequence indices and node IDs,
            with None marking a gap on the respective side.
        """
        besti, bestj = scores.shape[1] - 1, scores.shape[2] - 1
        # Best DP layer for every (i, j) cell.
        max_m = numpy.argmax(numpy.array(scores), axis=0)

        if self.globalAlign:
            ni = self.graph.nodeiterator()
            # A global alignment must end at a terminal (out-degree 0) node;
            # pick the terminal with the best end-cell score.
            # (Removed a stray debug print of the terminal indices.)
            terminalIndices = [index for (index, node) in enumerate(ni()) if node.outDegree == 0]
            besti = terminalIndices[0] + 1
            bestscore = scores[max_m[besti, bestj], besti, bestj]
            for i in terminalIndices[1:]:
                score = scores[max_m[i + 1, bestj], i + 1, bestj]
                if score > bestscore:
                    bestscore, besti = score, i + 1
        bestm = max_m[besti, bestj]

        matches = []
        strindexes = []

        while besti != 0 or bestj != 0:
            nextm = backMtxIdx[bestm, besti, bestj]
            nexti = backGrphIdx[bestm, besti, bestj]
            nextj = backStrIdx[bestm, besti, bestj]
            curstridx, curnodeidx = bestj - 1, nodeIndexToID[besti - 1]

            if bestm == 0:  # M: symbol aligned to node
                matches.insert(0, curnodeidx)
                strindexes.insert(0, curstridx)
            elif bestm == 1:  # X: gap in the graph
                matches.insert(0, None)
                strindexes.insert(0, curstridx)
            else:  # Y: gap in the sequence
                matches.insert(0, curnodeidx)
                strindexes.insert(0, None)

            bestm, besti, bestj = nextm, nexti, nextj

        return strindexes, matches
src/new_text_alignment.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from difflib import SequenceMatcher
2
+
3
+ import numpy as np
4
+
5
+ from .new_alignment import ScoreParam, SeqGraphAlignment
6
+
7
+ PUNCTUATION_MARKS = [".", "!", "?", ",", ":", ";", "...", "(", ")"]
8
+
9
class TextSeqGraphAlignment(SeqGraphAlignment):
    """Word-level sequence-to-graph alignment with fuzzy word matching.

    Words are compared with an edit-distance similarity rather than exact
    equality, so near-duplicates still score as (scaled) matches.
    """

    def __init__(
        self,
        text,
        graph,
        fastMethod=True,
        globalAlign=True,
        matchscore=1,
        mismatchscore=-3,
        gap_open=-2,
        gap_extend=-1,
        position_weight=0.1,
        *args,
        **kwargs,
    ):
        score_params = ScoreParam(
            match=matchscore, mismatch=mismatchscore, gap_open=gap_open, gap_extend=gap_extend
        )

        # Accept either a raw string (split on whitespace) or a word list.
        if isinstance(text, str):
            self.original_text = text
            self.sequence = text.split()
        else:
            self.sequence = text
            self.original_text = " ".join(text)
        # NOTE(review): stored but not used by the current scoring code.
        self.position_weight = position_weight

        super().__init__(
            self.sequence,
            graph,
            fastMethod,
            globalAlign=globalAlign,
            score_params=score_params,
            *args,
            **kwargs,
        )

    def string_similarity(self, s1, s2):
        """Get edit-distance based similarity between two strings."""
        return SequenceMatcher(None, s1, s2).ratio()

    def matchscore(self, word1: str, word2: str) -> float:
        """Similarity-scaled match score for two words.

        NOTE(review): the original carried punctuation-weighting code after
        the return statement; it was unreachable and has been removed
        without changing behavior.
        """
        similarity = self.string_similarity(word1, word2)

        # If words are very similar, treat as a full match (tunable threshold).
        if similarity > 0.8:
            return self.score.match
        # For moderately similar words, scale the reward by similarity.
        if similarity > 0.5:
            return self.score.match * similarity
        return self.score.mismatch

    def alignmentStrings(self):
        """Override to handle word-based alignment (space-joined output)."""
        aligned_seq = [self.sequence[i] if i is not None else "-" for i in self.stringidxs]
        aligned_graph = [
            self.graph.nodedict[j].text if j is not None else "-" for j in self.nodeidxs
        ]
        return " ".join(aligned_seq), " ".join(aligned_graph)

    def alignStringToGraphFast(self):
        """Affine-gap (3-layer Gotoh) DP of the word list against the graph."""
        if not isinstance(self.sequence, list):
            raise TypeError("Sequence must be a list of words")

        nodeIDtoIndex, nodeIndexToID, scores, backStrIdx, backGrphIdx, backMtxIdx = (
            self.initializeDynamicProgrammingData()
        )
        # M: match at last indices, X: gap at last index of graph,
        # Y: gap at last index of sequence.
        M, X, Y = 0, 1, 2

        ni = self.graph.nodeiterator()
        for i, node in enumerate(ni()):
            gbase = node.text

            for j, sbase in enumerate(self.sequence):
                candidates_X, candidates_Y, candidates_M = [], [], []
                # X: consume sequence word j with a gap in the graph.
                candidates_X += [
                    (self.score.gap_open + self.score.gap_extend + scores[0, i + 1, j], i + 1, j, M),
                    (self.score.gap_extend + scores[1, i + 1, j], i + 1, j, X),
                    (self.score.gap_open + self.score.gap_extend + scores[2, i + 1, j], i + 1, j, Y),
                ]
                # Hoisted: matchscore is pure, so compute it once per (i, j)
                # instead of three times per predecessor.
                ms = self.matchscore(sbase, gbase)
                for predIndex in self.prevIndices(node, nodeIDtoIndex):
                    # Y: consume this graph node with a gap in the sequence.
                    candidates_Y += [
                        (self.score.gap_open + self.score.gap_extend + scores[0, predIndex + 1, j + 1], predIndex + 1, j + 1, M),
                        (self.score.gap_open + self.score.gap_extend + scores[1, predIndex + 1, j + 1], predIndex + 1, j + 1, X),
                        (self.score.gap_extend + scores[2, predIndex + 1, j + 1], predIndex + 1, j + 1, Y),
                    ]
                    # M: align word j to this node, coming from any predecessor.
                    candidates_M += [
                        (ms + scores[0, predIndex + 1, j], predIndex + 1, j, M),
                        (ms + scores[1, predIndex + 1, j], predIndex + 1, j, X),
                        (ms + scores[2, predIndex + 1, j], predIndex + 1, j, Y),
                    ]

                (
                    scores[0, i + 1, j + 1],
                    backGrphIdx[0, i + 1, j + 1],
                    backStrIdx[0, i + 1, j + 1],
                    backMtxIdx[0, i + 1, j + 1],
                ) = max(candidates_M)
                (
                    scores[1, i + 1, j + 1],
                    backGrphIdx[1, i + 1, j + 1],
                    backStrIdx[1, i + 1, j + 1],
                    backMtxIdx[1, i + 1, j + 1],
                ) = max(candidates_X)
                (
                    scores[2, i + 1, j + 1],
                    backGrphIdx[2, i + 1, j + 1],
                    backStrIdx[2, i + 1, j + 1],
                    backMtxIdx[2, i + 1, j + 1],
                ) = max(candidates_Y)

        return self.backtrack(scores, backStrIdx, backGrphIdx, backMtxIdx, nodeIndexToID)
src/poa_graph.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from Jonathan Dursi
3
+ https://github.com/ljdursi/poapy
4
+ """
5
+
6
+ import collections
7
+ import textwrap
8
+ from typing import Dict, List, Optional, Union
9
+
10
+ import numpy
11
+
12
+ from .alignment import SeqGraphAlignment
13
+
14
+
15
class Node(object):
    """One symbol (character or word) in the partial-order alignment graph."""

    def __init__(self, nodeID: int = -1, text: str = ""):
        # Unique node ID, assigned by the owning POAGraph.
        self.ID = nodeID
        # The symbol carried by this node.
        self.text = text
        # predecessor node ID -> Edge / successor node ID -> Edge.
        self.inEdges = {}
        self.outEdges = {}
        # IDs of nodes this node is aligned to (same column, different text).
        self.alignedTo = []

    def __str__(self):
        return "(%d:%s)" % (self.ID, self.text)

    def _add_edge(
        self,
        edgeset: Dict[int, "Edge"],
        neighbourID: int,
        label: Union[int, List[int]],
        from_neighbour: bool,
        weight: int = 1,
    ):
        """Create or update the edge to *neighbourID* within *edgeset*."""
        if neighbourID is None:
            return
        # already present? just update labels
        # otherwise create appropriately-ordered edge and proceed
        if neighbourID in edgeset:
            edgeset[neighbourID].weight += weight
            if isinstance(label, list):
                edgeset[neighbourID].labels.extend(label)
            else:
                edgeset[neighbourID].labels.append(label)
            # remove duplicates
            # NOTE(review): set() discards the labels' original order.
            edgeset[neighbourID].labels = list(set(edgeset[neighbourID].labels))
        else:
            # Edge direction convention: outNodeID is the source node,
            # inNodeID the destination (see Edge.__str__).
            if from_neighbour:
                edge = Edge(outNodeID=neighbourID, inNodeID=self.ID, label=label, weight=weight)
            else:
                edge = Edge(outNodeID=self.ID, inNodeID=neighbourID, label=label, weight=weight)
            edgeset[neighbourID] = edge

    def addInEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
        # Record an edge arriving from *neighbourID*.
        self._add_edge(self.inEdges, neighbourID, label, from_neighbour=True, weight=weight)

    def addOutEdge(self, neighbourID: int, label: Optional[Union[int, List[int]]], weight: int = 1):
        # Record an edge leaving towards *neighbourID*.
        self._add_edge(self.outEdges, neighbourID, label, from_neighbour=False, weight=weight)

    def nextNode(self, label: int):
        """Return an outward neighbour whose edge carries the given label.

        NOTE(review): if several out-edges carry *label*, the LAST one in
        iteration order wins (there is no early break); presumably a label
        identifies a unique path through the graph — TODO confirm.
        """
        nextID = None
        for e in self.outEdges:
            if label in self.outEdges[e].labels:
                nextID = e
        return nextID

    @property
    def inDegree(self):
        # Number of distinct predecessor nodes.
        return len(self.inEdges)

    @property
    def outDegree(self):
        # Number of distinct successor nodes.
        return len(self.outEdges)

    @property
    def weightedInDegree(self):
        # Total edge weight over all incoming edges.
        return sum(edge.weight for edge in self.inEdges.values())

    @property
    def weightedOutDegree(self):
        # Total edge weight over all outgoing edges.
        return sum(edge.weight for edge in self.outEdges.values())

    @property
    def labels(self):
        """Returns all the labels associated with an in-edge or an out edge."""
        labelset = set([])
        for e in list(self.inEdges.values()):
            labelset = labelset.union(e.labels)
        for e in list(self.outEdges.values()):
            labelset = labelset.union(e.labels)
        return list(labelset)
93
+
94
+
95
class Edge(object):
    """Directed, weighted, label-carrying edge between two POA-graph nodes.

    Convention: ``outNodeID`` is the source node, ``inNodeID`` the
    destination; ``labels`` lists the sequence labels threading this edge.
    """

    def __init__(
        self,
        inNodeID: int = -1,
        outNodeID: int = -1,
        label: Optional[Union[int, List[int]]] = None,
        weight: int = 1,
    ):
        self.inNodeID = inNodeID
        self.outNodeID = outNodeID
        self.weight = weight

        # Normalise the label argument into the labels list.
        if label is None:
            self.labels = []
        elif isinstance(label, list):
            self.labels = label
        else:
            self.labels = [label]

    def addLabel(self, newlabel):
        """Append one more sequence label to this edge."""
        self.labels.append(newlabel)

    def __str__(self):
        arrow = "(%d) -> (%d) " % (self.inNodeID, self.outNodeID)
        if self.labels is None:
            return arrow
        return arrow + str(self.labels)
124
+
125
+
126
+ class POAGraph(object):
127
    def addUnmatchedSeq(self, seq, label: int = -1, updateSequences=True):
        """Add a completely independent (sub)sequence as a fresh node chain.

        Creates one node per symbol of *seq*, linked linearly with edges
        carrying *label*.

        Args:
            seq: Iterable of symbols (characters or words); None is a no-op.
            label: Sequence label stored on each created edge.
            updateSequences: When True, also record *seq* in the graph's
                sequence/label/start bookkeeping lists.

        Returns:
            (firstID, lastID) of the new chain, or None when *seq* is None.
        """
        if seq is None:
            return

        firstID, lastID = None, None
        neededSort = self.needsSort

        for text in seq:
            nodeID = self.addNode(text)
            if firstID is None:
                firstID = nodeID
            if lastID is not None:
                self.addEdge(lastID, nodeID, label)
            lastID = nodeID

        # A linear chain appended in creation order cannot break the
        # existing topological order.
        self._needsort = neededSort  # no new order problems introduced
        if updateSequences:
            self._seqs.append(seq)
            self._labels.append(label)
            self._starts.append(firstID)
        return firstID, lastID
150
+
151
+ def __init__(self, seq=None, label: Optional[Union[int, List[int]]] = None):
152
+ self._nextnodeID = 0
153
+ self._nnodes = 0
154
+ self._nedges = 0
155
+ self.nodedict = {}
156
+ self.nodeidlist = [] # allows a (partial) order to be imposed on the nodes
157
+ self._needsort = False
158
+ self._labels = []
159
+ self._seqs = []
160
+ self._starts = []
161
+
162
+ if seq is not None:
163
+ self.addUnmatchedSeq(seq, label)
164
+
165
    def nodeIdxToBase(self, idx):
        """Return the text of the node at position *idx* of the sorted order."""
        return self.nodedict[self.nodeidlist[idx]].text
167
+
168
+ def addNode(self, text):
169
+ nid = self._nextnodeID
170
+ newnode = Node(nid, text)
171
+ self.nodedict[nid] = newnode
172
+ self.nodeidlist.append(nid)
173
+ self._nnodes += 1
174
+ self._nextnodeID += 1
175
+ self._needsSort = True
176
+ return nid
177
+
178
+ def addEdge(self, start, end, label, weight: int = 1):
179
+ if start is None or end is None:
180
+ return
181
+
182
+ if start not in self.nodedict:
183
+ raise KeyError("addEdge: Start node not in graph: " + str(start))
184
+ if end not in self.nodedict:
185
+ raise KeyError("addEdge: End node not in graph: " + str(end))
186
+
187
+ oldNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree
188
+
189
+ self.nodedict[start].addOutEdge(end, label, weight)
190
+ self.nodedict[end].addInEdge(start, label, weight)
191
+
192
+ newNodeEdges = self.nodedict[start].outDegree + self.nodedict[end].inDegree
193
+
194
+ if newNodeEdges != oldNodeEdges:
195
+ self._nedges += 1
196
+
197
+ self._needsSort = True
198
+ return
199
+
200
    @property
    def needsSort(self):
        # True when the node order may be stale and toposort() should run.
        return self._needsort

    @property
    def nNodes(self):
        # Number of nodes currently in the graph.
        return self._nnodes

    @property
    def nEdges(self):
        # Number of distinct directed edges.
        return self._nedges

    @property
    def num_sequences(self):
        # Number of sequences recorded in the graph's bookkeeping.
        return len(self._seqs)

    def get_sequences(self):
        """Return the list of raw sequences inserted into the graph."""
        return self._seqs
218
+
219
    def _simplified_graph_rep(self):
        """Collapse mutually-aligned nodes into 'pseudonodes' for toposort.

        Returns:
            A list of Pseudonode namedtuples; predecessors/successors hold
            pseudonode IDs (duplicates possible — harmless for the DFS).
        """

        node_to_pn = {}
        pn_to_nodes = {}

        # Find the mappings from nodes to pseudonodes
        cur_pnid = 0
        for _, node in self.nodedict.items():
            if node.ID not in node_to_pn:
                # A pseudonode groups a node with everything it is aligned to.
                node_ids = [node.ID] + node.alignedTo
                pn_to_nodes[cur_pnid] = node_ids
                for nid in node_ids:
                    node_to_pn[nid] = cur_pnid
                cur_pnid += 1

        # create the pseudonodes
        Pseudonode = collections.namedtuple(
            "Pseudonode", ["pnode_id", "predecessors", "successors", "node_ids"]
        )
        pseudonodes = []

        for pnid in range(cur_pnid):
            nids, preds, succs = pn_to_nodes[pnid], [], []
            for nid in nids:
                node = self.nodedict[nid]
                # Edge convention: outNodeID is the source, inNodeID the
                # destination of the directed edge.
                preds += [node_to_pn[inEdge.outNodeID] for _, inEdge in node.inEdges.items()]
                succs += [node_to_pn[outEdge.inNodeID] for _, outEdge in node.outEdges.items()]

            pn = Pseudonode(pnode_id=pnid, predecessors=preds, successors=succs, node_ids=nids)
            pseudonodes.append(pn)

        return pseudonodes
251
+
252
+ def toposort(self):
253
+ """Sorts node list so that all incoming edges come from nodes earlier in the list."""
254
+ sortedlist = []
255
+ completed = set([])
256
+
257
+ #
258
+ # The topological sort of this graph is complicated by the alignedTo edges;
259
+ # we want to nodes connected by such edges to remain near each other in the
260
+ # topological sort.
261
+ #
262
+ # Here we'll create a simple version of the graph that merges nodes that
263
+ # are alignedTo each other, performs the sort, and then decomposes the
264
+ # 'pseudonodes'.
265
+ #
266
+ # The need for this suggests that the way the graph is currently represented
267
+ # isn't quite right and needs some rethinking.
268
+ #
269
+
270
+ pseudonodes = self._simplified_graph_rep()
271
+
272
+ def dfs(start, complete, sortedlist):
273
+ stack, started = [start], set()
274
+ while stack:
275
+ pnodeID = stack.pop()
276
+
277
+ if pnodeID in complete:
278
+ continue
279
+
280
+ if pnodeID in started:
281
+ complete.add(pnodeID)
282
+ for nid in pseudonodes[pnodeID].node_ids:
283
+ sortedlist.insert(0, nid)
284
+ started.remove(pnodeID)
285
+ continue
286
+
287
+ successors = pseudonodes[pnodeID].successors
288
+ started.add(pnodeID)
289
+ stack.append(pnodeID)
290
+ stack.extend(successors)
291
+
292
+ while len(sortedlist) < self.nNodes:
293
+ found = None
294
+ for pnid in range(len(pseudonodes)):
295
+ if pnid not in completed and len(pseudonodes[pnid].predecessors) == 0:
296
+ found = pnid
297
+ break
298
+ assert found is not None
299
+ dfs(found, completed, sortedlist)
300
+
301
+ assert len(sortedlist) == self.nNodes
302
+ self.nodeidlist = sortedlist
303
+ self._needsSort = False
304
+ return
305
+
306
    def testsort(self):
        """Test the nodeidlist to make sure it is topologically sorted:
        eg, all predecessors of a node preceed the node in the list"""
        if self.nodeidlist is None:
            return
        seen_nodes = set()
        for nodeidx in self.nodeidlist:
            node = self.nodedict[nodeidx]
            # Every in-neighbour must already have appeared earlier in the list.
            for in_neighbour in node.inEdges:
                assert in_neighbour in seen_nodes
            seen_nodes.add(nodeidx)
        return
318
+
319
    def nodeiterator(self):
        """Return a generator factory yielding nodes in topological order.

        Re-sorts first when the graph's dirty flag indicates modifications
        since the last sort.
        """
        if self.needsSort:
            self.toposort()

        def nodegenerator():
            for nodeidx in self.nodeidlist:
                yield self.nodedict[nodeidx]

        return nodegenerator
328
+
329
    def __str__(self):
        # One line per node, followed by one line per outgoing edge.
        selfstr = ""
        ni = self.nodeiterator()
        for node in ni():
            selfstr += node.__str__() + "\n"
            for outIdx in node.outEdges:
                selfstr += " " + node.outEdges[outIdx].__str__() + "\n"
        return selfstr
337
+
338
+ def incorporateSeqAlignment(self, alignment: SeqGraphAlignment, seq, label: int = -1):
339
+ """Incorporate a SeqGraphAlignment into the graph."""
340
+ newseq = alignment.sequence
341
+ stringidxs = alignment.stringidxs
342
+ nodeidxs = alignment.nodeidxs
343
+
344
+ firstID = None
345
+ headID = None
346
+ tailID = None
347
+
348
+ path = []
349
+ # head, tail of sequence may be unaligned; just add those into the
350
+ # graph directly
351
+ validstringidxs = [si for si in stringidxs if si is not None]
352
+ startSeqIdx, endSeqIdx = validstringidxs[0], validstringidxs[-1]
353
+ if startSeqIdx > 0:
354
+ firstID, headID = self.addUnmatchedSeq(
355
+ newseq[0:startSeqIdx], label, updateSequences=False
356
+ )
357
+ if endSeqIdx < len(newseq):
358
+ tailID, __ = self.addUnmatchedSeq(newseq[endSeqIdx + 1 :], label, updateSequences=False)
359
+
360
+ # now we march along the aligned part. For each text, we find or create
361
+ # a node in the graph:
362
+ # - if unmatched, the corresponding node is a new node
363
+ # - if matched:
364
+ # - if matched to a node with the same text, the node is that node
365
+ # - if matched to a node with a different text whch is in turn
366
+ # aligned to a node with the same text, that aligned node is
367
+ # the node
368
+ # - otherwise, we create a new node.
369
+ # In all cases, we create edges (or add labels) threading through the
370
+ # nodes.
371
+ for sindex, matchID in zip(stringidxs, nodeidxs):
372
+ if sindex is None:
373
+ continue
374
+ text = newseq[sindex]
375
+ if matchID is None:
376
+ nodeID = self.addNode(text)
377
+ elif self.nodedict[matchID].text == text:
378
+ nodeID = matchID
379
+ else:
380
+ otherAligns = self.nodedict[matchID].alignedTo
381
+ foundNode = None
382
+ for otherNodeID in otherAligns:
383
+ if self.nodedict[otherNodeID].text == text:
384
+ foundNode = otherNodeID
385
+ if foundNode is None:
386
+ nodeID = self.addNode(text)
387
+ self.nodedict[nodeID].alignedTo = [matchID] + otherAligns
388
+ for otherNodeID in [matchID] + otherAligns:
389
+ self.nodedict[otherNodeID].alignedTo.append(nodeID)
390
+ else:
391
+ nodeID = foundNode
392
+
393
+ self.addEdge(headID, nodeID, label)
394
+ headID = nodeID
395
+ if firstID is None:
396
+ firstID = headID
397
+
398
+ path.append(nodeID)
399
+
400
+ # finished the unaligned portion: now add an edge from the current headID to the tailID.
401
+ self.addEdge(headID, tailID, label)
402
+
403
+ # resort
404
+ self.toposort()
405
+
406
+ self._seqs.append(seq)
407
+ self._labels.append(label)
408
+ self._starts.append(firstID)
409
+ self._seq_paths[label] = path
410
+ return
411
+
412
+ def consensus(self, excludeLabels=None):
413
+ if excludeLabels is None:
414
+ excludeLabels = []
415
+
416
+ if self.needsSort:
417
+ self.toposort()
418
+
419
+ nodesInReverse = self.nodeidlist[::-1]
420
+ maxnodeID = max(nodesInReverse) + 1
421
+ nextInPath = [-1] * maxnodeID
422
+ scores = numpy.zeros((maxnodeID))
423
+
424
+ for nodeID in nodesInReverse:
425
+ bestWeightScoreEdge = (-1, -1, None)
426
+ for neighbourID in self.nodedict[nodeID].outEdges:
427
+ # print(f"nodeID: {nodeID}, neighbourID: {neighbourID}")
428
+ e = self.nodedict[nodeID].outEdges[neighbourID]
429
+ weightScoreEdge = (e.weight, scores[neighbourID], neighbourID)
430
+
431
+ if weightScoreEdge > bestWeightScoreEdge:
432
+ bestWeightScoreEdge = weightScoreEdge
433
+
434
+ scores[nodeID] = sum(bestWeightScoreEdge[0:2])
435
+ nextInPath[nodeID] = bestWeightScoreEdge[2]
436
+
437
+ pos = numpy.argmax(scores)
438
+ path = []
439
+ bases = []
440
+ labels = []
441
+ while pos is not None and pos > -1:
442
+ path.append(pos)
443
+ bases.append(self.nodedict[pos].text)
444
+ labels.append(self.nodedict[pos].labels)
445
+ pos = nextInPath[pos]
446
+
447
+ # ignore END node
448
+ path = path[:-1]
449
+ bases = bases[:-1]
450
+ labels = labels[:-1]
451
+ return path, bases, labels
452
+
453
def allConsenses(self, maxfraction=0.5):
    """Iteratively extract consensus paths from the graph.

    After each consensus, every sequence whose nodes cover at least
    *maxfraction* of its length is excluded; iteration stops when all
    sequences are excluded, the last consensus was shorter than 10 nodes,
    or 10 passes have run. Returns a list of (path, bases, labels) tuples.

    NOTE(review): *exclusions* is handed to consensus(), whose
    excludeLabels parameter currently has no effect — confirm intent.
    """
    allpaths = []
    allbases = []
    alllabels = []
    exclusions = []

    passno = 0
    lastlen = 1000  # sentinel so the first loop test passes
    maxpasses = 10

    while len(exclusions) < len(self._labels) and lastlen >= 10 and passno < maxpasses:
        path, bases, labellists = self.consensus(exclusions)
        if len(path) > 0:
            allpaths.append(path)
            allbases.append(bases)
            alllabels.append(labellists)

            # how many consensus nodes each sequence label appears on
            labelcounts = collections.defaultdict(int)
            for ll in labellists:
                for label in ll:
                    labelcounts[label] += 1

            # exclude sequences sufficiently represented by this consensus
            for label, seq in zip(self._labels, self._seqs):
                if label in labelcounts and labelcounts[label] >= maxfraction * len(seq):
                    exclusions.append(label)

        lastlen = len(path)
        passno += 1

    return list(zip(allpaths, allbases, alllabels))
483
+
484
def generateAlignmentStrings(self):
    """Return a list of (name, aligned-string) pairs for every sequence in
    the graph, followed by the consensus sequences from allConsenses()."""

    # Step 1: assign node IDs to columns in the output
    # column_index[node.ID] is the position in the toposorted node list
    # of the node itself, or the earliest node it is aligned to.
    column_index = {}
    current_column = 0

    # go through nodes in toposort order so aligned nodes share a column
    ni = self.nodeiterator()
    for node in ni():
        other_columns = [
            column_index[other] for other in node.alignedTo if other in column_index
        ]
        if other_columns:
            # reuse the earliest column of any already-placed aligned node
            found_idx = min(other_columns)
        else:
            found_idx = current_column
            current_column += 1

        column_index[node.ID] = found_idx

    ncolumns = current_column

    # Step 2: given the column indexes, populate the strings
    # corresponding to the sequences inserted in the graph
    seqnames = []
    alignstrings = []
    for label, start in zip(self._labels, self._starts):
        seqnames.append(label)
        curnode_id = start
        charlist = ["-"] * ncolumns  # gap everywhere the sequence is absent
        while curnode_id is not None:
            node = self.nodedict[curnode_id]
            charlist[column_index[curnode_id]] = node.text
            curnode_id = node.nextNode(label)
        alignstrings.append("".join(charlist))

    # Step 3: Same as step 2, but with consensus sequences
    consenses = self.allConsenses()
    for i, consensus in enumerate(consenses):
        seqnames.append("Consensus" + str(i))
        charlist = ["-"] * ncolumns
        for path, text in zip(consensus[0], consensus[1]):
            charlist[column_index[path]] = text
        alignstrings.append("".join(charlist))

    return list(zip(seqnames, alignstrings))
533
+
534
def jsOutput(self, verbose: bool = False, annotate_consensus: bool = True):
    """Return JavaScript source lines declaring the `nodes` and `edges`
    arrays describing this graph for vis.js, http://visjs.org.

    verbose            -- also emit undirected dashed red edges for the
                          alignedTo relationships.
    annotate_consensus -- pin every 5th consensus node at a fixed x and
                          colour it green, forming the graph's "spine".

    Fixes over the previous revision:
    - pathdict was declared but never populated, so the consensus
      annotation branch was unreachable; it is now filled from
      self.consensus().
    - count was never incremented, making `count % 5` always 0.
    - the emitted flag is now `isConsensus`, matching the htmlOutput()
      footer which reads `node.isConsensus` (was `is_consensus`).
    """
    # map consensus node IDs to fixed x coordinates along the spine
    pathdict = {}
    if annotate_consensus:
        path, __, __ = self.consensus()
        for i, nodeID in enumerate(path):
            pathdict[nodeID] = i * 150  # uniform horizontal spacing

    lines = ["var nodes = ["]

    ni = self.nodeiterator()
    for count, node in enumerate(ni()):
        line = "    {id:" + str(node.ID) + ', label: "' + str(node.ID) + ": " + node.text + '"'
        if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
            line += (
                ", x: "
                + str(pathdict[node.ID])
                + ", y: 0 , fixed: { x:true, y:false},"
                + "color: '#7BE141', isConsensus:true},"
            )
        else:
            line += "},"
        lines.append(line)

    lines[-1] = lines[-1][:-1]  # drop trailing comma of the last node entry
    lines.append("];")

    lines.append(" ")

    lines.append("var edges = [")
    ni = self.nodeiterator()
    for node in ni():
        nodeID = str(node.ID)
        for edge in node.outEdges:
            target = str(edge)
            # edge thickness scales with the number of supporting sequences
            weight = str(len(node.outEdges[edge].labels) + 1.5)
            lines.append(
                "    {from: "
                + nodeID
                + ", to: "
                + target
                + ", value: "
                + weight
                + ", color: '#4b72b0', arrows: 'to'},"
            )
        if verbose:
            for alignededge in node.alignedTo:
                # These edges indicate alignment to different bases, and are
                # undirected; thus make sure we only plot them once:
                if node.ID > alignededge:
                    continue
                target = str(alignededge)
                lines.append(
                    "    {from: "
                    + nodeID
                    + ", to: "
                    + target
                    + ', value: 1, style: "dash-line", color: "red"},'
                )

    lines[-1] = lines[-1][:-1]
    lines.append("];")
    return lines
598
+
599
def htmlOutput(self, outfile, verbose: bool = False, annotate_consensus: bool = True):
    """Write a self-contained HTML page rendering the graph with vis.js.

    outfile -- an open, writable text file handle.
    verbose / annotate_consensus are forwarded to jsOutput().
    """
    # header[1:] drops the leading newline before dedenting
    header = """
        <!doctype html>
        <html>
        <head>
          <title>POA Graph Alignment</title>

          <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
        </head>

        <body>

        <div id="loadingProgress">0%</div>

        <div id="mynetwork"></div>

        <script type="text/javascript">
          // create a network
        """
    outfile.write(textwrap.dedent(header[1:]))
    lines = self.jsOutput(verbose=verbose, annotate_consensus=annotate_consensus)
    for line in lines:
        outfile.write(line + "\n")
    footer = """
        var container = document.getElementById('mynetwork');
        var data= {
            nodes: nodes,
            edges: edges,
        };
        var options = {
            width: '100%',
            height: '800px',
            physics: {
                enabled: false,
                stabilization: {
                    updateInterval: 10,
                },
                hierarchicalRepulsion: {
                    avoidOverlap: 0.9,
                },
            },
            edges: {
                color: {
                    inherit: false
                }
            },
            layout: {
                hierarchical: {
                    direction: "UD",
                    sortMethod: "directed",
                    shakeTowards: "roots",
                    levelSeparation: 150, // Adjust as needed
                    nodeSpacing: 100, // Adjust as needed
                    treeSpacing: 200, // Adjust as needed
                    parentCentralization: true,
                }
            }
        };
        var network = new vis.Network(container, data, options);

        network.on('beforeDrawing', function(ctx) {
            nodes.forEach(function(node) {
                if (node.isConsensus) {
                    // Set the level of spine nodes to the bottom
                    network.body.data.nodes.update({
                        id: node.id,
                        level: 0 // Set level to 0 for spine nodes
                    });
                }
            });
        });

        network.on("stabilizationProgress", function (params) {
            document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
        });
        network.once("stabilizationIterationsDone", function () {
            document.getElementById("loadingProgress").innerText = "100%";
            setTimeout(function () {
                document.getElementById("loadingProgress").style.display = "none";
            }, 500);
        });
        </script>

        </body>
        </html>
        """
    outfile.write(textwrap.dedent(footer))
src/text_poa_graph.py ADDED
@@ -0,0 +1,802 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced version of POAGraph for text alignment
3
+ """
4
+
5
+ import pickle
6
+ import textwrap
7
+ from typing import Dict, Optional
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from src.text_poa_graph_utils import path_sim_llm
13
+ from src.global_edit_utils import clean_up_text
14
+
15
+ from .new_text_alignment import TextSeqGraphAlignment
16
+ from .poa_graph import Node, POAGraph
17
+
18
+
19
class TextNode(Node):
    """Graph node holding a word (or merged phrase) of text, plus metadata
    about the sequences that pass through it."""

    def __init__(self, nodeID=-1, text=""):
        super().__init__(nodeID, text)
        self.variations = {}  # sequence_id -> alternate phrasing for this node
        self.sequences = []   # sequence labels that contain this node
        self.influenceScore = 0
        self.num_tokens_used = 0

    def add_variation(self, text, sequence_id):
        # record an alternate phrasing observed for this node in one sequence
        self.variations[sequence_id] = text

    @property
    def is_stable(self):
        """A node is stable if it appears frequently enough relative to total sequences"""
        # NOTE(review): relies on self.frequency and self.graph, which are not
        # assigned anywhere in this file — confirm the caller sets them, or
        # accessing this property will raise AttributeError.
        return self.frequency >= self.graph.stability_threshold
34
+
35
+
36
class TextPOAGraph(POAGraph):
    """Partial-order alignment graph specialised for word-level text.

    Extends POAGraph with consensus-segment merging, LLM-based merging of
    semantically equivalent divergent paths, and vis.js visualisation.
    """

    def __init__(self, text=None, label=-1):
        self.consensus_node_ids = []  # spine node IDs found by merge_consensus_nodes()
        self._seq_paths = {}          # sequence label -> list of node IDs
        self.end_id = -1              # dummy END node id (-1 until added)
        self.start_id = -1            # dummy START node id (-1 until added)
        self.failed = False           # set when refinement cannot proceed
        self.num_input_tokens_used = 0   # LLM token accounting
        self.num_output_tokens_used = 0
        super().__init__(text, label)
46
+
47
def addNode(self, text):
    """Create a TextNode holding *text*, register it, and return its ID."""
    node_id = self._nextnodeID
    self._nextnodeID += 1
    self.nodedict[node_id] = TextNode(node_id, text)
    self.nodeidlist.append(node_id)
    self._nnodes += 1
    # any insertion invalidates the cached topological order
    self._needsSort = True
    return node_id
57
+
58
def addUnmatchedSeq(self, text, label=-1, updateSequences=True):
    """Insert *text* as a brand-new linear chain of nodes, no alignment.

    text can be a whitespace-separated string or a pre-tokenised list of
    words. Returns (firstID, lastID) of the new chain, or None when text
    is None.

    Fix: the sort-flag restore previously assigned self._needsort (typo);
    the rest of this file uses self._needsSort (addNode, removeNode), so
    the flag set by addNode() was never restored. TODO confirm the same
    spelling is used by the POAGraph base class.
    """
    if text is None:
        return

    # Handle both string and list input
    words = text.split() if isinstance(text, str) else text

    firstID, lastID = None, None
    neededSort = self.needsSort  # pure appends don't invalidate an existing order

    path = []
    for word in words:
        nodeID = self.addNode(word)
        if firstID is None:
            firstID = nodeID
        if lastID is not None:
            self.addEdge(lastID, nodeID, label=label)
        lastID = nodeID
        path.append(nodeID)

    # restore the pre-insertion sort flag (was: self._needsort)
    self._needsSort = neededSort
    if updateSequences:
        self._seqs.append(words)
        self._labels.append(label)
        self._starts.append(firstID)
        self._seq_paths[label] = path

    return firstID, lastID
90
+
91
def add_text(self, text, label=-1):
    """Main method to add new text to the alignment.

    The first sequence is inserted verbatim; later sequences are aligned
    against the existing graph and incorporated.
    """
    if len(self._seqs) == 0:
        # First sequence - just add it
        self.addUnmatchedSeq(text, label)
    else:
        # Align to existing graph
        alignment = TextSeqGraphAlignment(
            text, self, matchscore=2, mismatchscore=-1, gapscore=-2
        )
        self.incorporateSeqAlignment(alignment, text, label)

    # Update node frequencies
    # NOTE(review): _update_frequencies is not defined anywhere in this
    # file — confirm it exists on POAGraph, otherwise this raises
    # AttributeError.
    self._update_frequencies()
105
+
106
def removeNode(self, nodeID):
    """Remove *nodeID* and all of its incident edges from the graph.

    Unknown IDs are silently ignored. Fix: the previous version indexed
    self.nodedict directly, which raised KeyError before its
    `if node is None` guard could ever fire; `.get` makes the guard real.
    """
    node = self.nodedict.get(nodeID)
    if node is None:
        return

    # Detach every incident edge; iterate over copies because removeEdge
    # mutates the underlying dicts.
    for edge in node.outEdges.copy():
        self.removeEdge(node.ID, edge)
    for edge in node.inEdges.copy():
        self.removeEdge(edge, node.ID)

    # Remove from graph
    del self.nodedict[nodeID]
    self.nodeidlist.remove(nodeID)

    # Drop the node from every recorded sequence path.
    for path in self._seq_paths.values():
        if nodeID in path:
            path.remove(nodeID)

    self._nnodes -= 1
    self._needsSort = True
131
+
132
def removeEdge(self, nodeID1, nodeID2):
    """Remove the directed edge nodeID1 -> nodeID2, if both endpoints exist.

    Fix: the previous version indexed self.nodedict directly, so the
    `if node1 is None or node2 is None` guard was unreachable and unknown
    IDs raised KeyError; `.get` makes the guard effective.
    """
    node1 = self.nodedict.get(nodeID1)
    node2 = self.nodedict.get(nodeID2)

    if node1 is None or node2 is None:
        return

    # Drop both directions of the edge record.
    del node1.outEdges[nodeID2]
    del node2.inEdges[nodeID1]
143
+
144
def merge_consensus_nodes(self, verbose: bool = False):
    """Find runs of nodes traversed by every sequence and merge each run
    into a single node whose text is the joined run text.

    Populates self.consensus_node_ids with the surviving (last) node ID of
    each merged segment. A node belongs to the spine when the sum of its
    in/out edge weights equals self.num_sequences (or 0 at the ends).
    """
    self.toposort()
    # reset consensus node ids
    self.consensus_node_ids = []
    nodes = list(self.nodeiterator()())
    consensus_segments = []
    i = 0
    while i < len(nodes):
        node = nodes[i]
        out_weight = sum(e.weight for e in node.outEdges.values())
        in_weight = sum(e.weight for e in node.inEdges.values())

        if out_weight in [0, self.num_sequences] and in_weight in [0, self.num_sequences]:
            # start a consensus segment and greedily extend it while the
            # next toposorted node is also traversed by all sequences
            consensus_segment = [(node.ID, node.text)]
            next_node = node
            while (i + 1) < len(nodes) and len(next_node.outEdges) == 1:
                next_node = nodes[i + 1]
                next_out_weight = sum(e.weight for e in next_node.outEdges.values())
                next_in_weight = sum(e.weight for e in next_node.inEdges.values())

                if (
                    next_out_weight != self.num_sequences
                    or next_in_weight != self.num_sequences
                ):
                    break

                consensus_segment.append((next_node.ID, next_node.text))
                i += 1
            consensus_segments.append(consensus_segment)
        i += 1
    # merge consensus nodes into a single node
    for segment in consensus_segments:
        if len(segment) == 1:
            self.consensus_node_ids.append(segment[0][0])
            continue
        merged_text = " ".join([text for _, text in segment])
        first_node_id = segment[0][0]
        last_node_id = segment[-1][0]

        # the LAST node of the run survives and carries the merged text
        self.nodedict[last_node_id].text = merged_text
        self.consensus_node_ids.append(last_node_id)

        # attach all incoming edges to first node to last node
        for id, edge in self.nodedict[first_node_id].inEdges.items():
            weight = edge.weight
            for _ in range(weight):
                self.addEdge(id, last_node_id, label=edge.labels)

        # delete all nodes except last node
        for node_id, _ in segment[:-1]:
            self.removeNode(node_id)

    if verbose:
        print(self.consensus_node_ids)
200
+
201
+ """
202
+ find all paths between start_node_id and end_node_id from original sequences
203
+ return a list of dictionaries with the following keys:
204
+ - path: list of node ids in the path (excluding start and including end)
205
+ - text: text of the path (excluding start and end)
206
+ - weight: minimal edge weight across all edges in the path
207
+ - labels: intersection of all edge labels in the path
208
+ """
209
+
210
+ def find_paths_between(self, start_node_id: int, end_node_id: int):
211
+ # find all paths between start_node_id and end_node_id from original sequences
212
+ path_dicts = []
213
+
214
+ # keep track of visited paths to avoid duplicates
215
+ visited_paths = set()
216
+
217
+ for _, path in self._seq_paths.items():
218
+ start_index = path.index(start_node_id) if start_node_id in path else None
219
+ end_index = path.index(end_node_id) if end_node_id in path else None
220
+
221
+ # print(start_index, end_index)
222
+ # print(path)
223
+
224
+ if (
225
+ start_index is not None
226
+ and end_index is not None
227
+ and end_index - start_index > 1
228
+ and tuple(path[start_index + 1 : end_index + 1]) not in visited_paths
229
+ ):
230
+ # intersection of all edge labels in the path
231
+ path_labels = set.intersection(
232
+ *[
233
+ set(self.nodedict[next_node_id].inEdges[node_id].labels)
234
+ for node_id, next_node_id in zip(
235
+ path[start_index:end_index], path[start_index + 1 : end_index + 1]
236
+ )
237
+ ]
238
+ )
239
+ path_weight = len(path_labels)
240
+ path_dicts.append(
241
+ {
242
+
243
+ "path": path[start_index + 1 : end_index + 1],
244
+ "body_text": " ".join(
245
+ [
246
+ self.nodedict[node_id].text
247
+ for node_id in path[start_index + 1 : end_index]
248
+ ]
249
+ ),
250
+ "begin_text": self.nodedict[path[start_index]].text,
251
+ "end_text": self.nodedict[path[end_index]].text,
252
+ "weight": path_weight,
253
+ "labels": path_labels,
254
+ }
255
+ )
256
+ visited_paths.add(tuple(path[start_index + 1 : end_index + 1]))
257
+
258
+ return path_dicts
259
+
260
+ def _follow_path(self, start_id):
261
+ """Follow all possible paths from a node"""
262
+ paths = []
263
+ visited = set()
264
+
265
+ def dfs(node_id, current_path):
266
+ if node_id in visited:
267
+ return
268
+ visited.add(node_id)
269
+ node = self.nodedict[node_id]
270
+
271
+ if not node.outEdges:
272
+ paths.append(current_path + [node_id])
273
+ return
274
+
275
+ for next_id in node.outEdges:
276
+ dfs(next_id, current_path + [node_id])
277
+
278
+ dfs(start_id, [])
279
+ return paths
280
+
281
def merge_paths_between(
    self,
    start_node_id: int,
    end_node_id: int,
    path_sim_type: str = "llm",
    verbose: bool = False,
    **kwargs,
):
    """Collapse semantically-equivalent paths between two consensus nodes.

    All distinct sequence paths between start_node_id and end_node_id are
    grouped into equivalence classes by a pairwise similarity judge; each
    class is replaced by one new node carrying the representative text,
    with the alternate phrasings stored in node.variations.

    path_sim_type -- "llm" uses path_sim_llm (kwargs: api, model, domain,
                     similarity_judge_prompt); "cosine" is an unimplemented
                     stub; anything else raises ValueError.

    Fix: with path_sim_type == "cosine" the previous version left
    path_sim_func undefined and crashed with NameError at the first
    comparison; the stub now raises NotImplementedError with a clear
    message when actually used (still a no-op when there are no paths).
    """
    path_dicts = self.find_paths_between(start_node_id, end_node_id)

    if path_sim_type == "llm":
        api = kwargs.get("api", "openai")
        model = kwargs.get("model", "gpt-4o-mini")
        domain = kwargs.get("domain", None)
        similarity_judge_prompt = kwargs.get("similarity_judge_prompt", None)

        def path_sim_func(path1_text, path2_text):
            return path_sim_llm(
                path1_text,
                path2_text,
                api=api,
                model=model,
                domain=domain,
                custom_similarity_judge_prompt=similarity_judge_prompt,
            )

    elif path_sim_type == "cosine":
        # stub: embedding-based similarity has not been implemented yet
        def path_sim_func(path1_text, path2_text):
            raise NotImplementedError("cosine path similarity is not implemented")

    else:
        raise ValueError(f"Invalid path similarity type: {path_sim_type}")

    # merge paths based on semantic similarity: greedy single-link grouping
    # against the first member of each existing equivalence class
    path_equivalence_classes = {}
    class_count = 0

    for path_dict in path_dicts:
        if verbose:
            print(path_dict)
        found_class = False
        for _, eq_class in path_equivalence_classes.items():
            # compare "begin body end" context strings of both paths
            path1_text = (
                path_dict["begin_text"]
                + " "
                + path_dict["body_text"]
                + " "
                + path_dict["end_text"]
            )
            path2_text = (
                eq_class[0]["begin_text"]
                + " "
                + eq_class[0]["body_text"]
                + " "
                + eq_class[0]["end_text"]
            )

            judgement, num_input_tokens, num_output_tokens = path_sim_func(
                path1_text, path2_text
            )
            self.num_input_tokens_used += num_input_tokens
            self.num_output_tokens_used += num_output_tokens
            if judgement:
                eq_class.append(path_dict)
                found_class = True
                break
        if not found_class:
            class_count += 1
            path_equivalence_classes[class_count] = [path_dict]

    nodes_to_remove = set()  # nodes absorbed into merged replacements
    for _, eq_class in path_equivalence_classes.items():
        # the first path of the class provides the representative text
        path_dict = eq_class[0]

        if verbose:
            print(eq_class)
        new_node_id = self.addNode(path_dict["body_text"])
        for sequence_id in path_dict["labels"]:
            self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]

        # old interior nodes (all but the shared end node) get removed later
        nodes_to_remove.update(path_dict["path"][:-1])

        labels = list(path_dict["labels"])
        weight = path_dict["weight"]
        self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)

        # splice the new node into every affected sequence path
        for label in labels:
            index = self._seq_paths[label].index(start_node_id)
            if (
                index + 1 < len(self._seq_paths[label])
                and self._seq_paths[label][index + 1] != new_node_id
            ):
                self._seq_paths[label].insert(index + 1, new_node_id)

        self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)

        self.nodedict[new_node_id].sequences = labels
        # fold the remaining paths of the class into the same new node
        for path_dict in eq_class[1:]:
            for sequence_id in path_dict["labels"]:
                self.nodedict[new_node_id].variations[sequence_id] = path_dict["body_text"]
            nodes_to_remove.update(path_dict["path"][:-1])

            labels = list(path_dict["labels"])
            weight = path_dict["weight"]
            self.addEdge(start_node_id, new_node_id, label=labels, weight=weight)

            for label in labels:
                index = self._seq_paths[label].index(start_node_id)
                if (
                    index + 1 < len(self._seq_paths[label])
                    and self._seq_paths[label][index + 1] != new_node_id
                ):
                    self._seq_paths[label].insert(index + 1, new_node_id)

            self.addEdge(new_node_id, end_node_id, label=labels, weight=weight)
            self.nodedict[new_node_id].sequences.extend(labels)

        self.nodedict[new_node_id].sequences = list(set(self.nodedict[new_node_id].sequences))

    # Remove all collected nodes after processing
    for node_id in nodes_to_remove:
        if node_id in self.nodedict:
            if verbose:
                print(f"Removing node {node_id}")
            self.removeNode(node_id)
415
+
416
def merge_divergent_paths(self, path_sim_type: str = "llm", verbose: bool = False, **kwargs):
    """Merge semantically-equivalent divergent paths between every pair of
    adjacent consensus nodes.

    Adds dummy START/END nodes (once) so all sequence paths share common
    endpoints, then calls merge_paths_between() on each consecutive pair
    of consensus nodes.

    Fix: the extra `self._nextnodeID += 1` after each addNode() call was
    removed — addNode() already increments the counter, so the duplicate
    bump skipped node IDs.
    """
    # make sure there is a consensus spine to anchor the merging on
    if not self.consensus_node_ids:
        self.merge_consensus_nodes(verbose=verbose)

    self.toposort()

    if self.start_id == -1:
        if verbose:
            print("Adding start node")
        self.start_id = self.addNode(text="START")
        self.consensus_node_ids.insert(0, self.start_id)

        # link START in front of every recorded sequence path
        for label, path in self._seq_paths.items():
            self.addEdge(self.start_id, path[0], label=label, weight=1)
            path.insert(0, self.start_id)

    if self.end_id == -1:
        if verbose:
            print("Adding end node")
        self.end_id = self.addNode(text="END")
        self.consensus_node_ids = self.consensus_node_ids + [self.end_id]

        # link END behind every recorded sequence path
        for label, path in self._seq_paths.items():
            self.addEdge(path[-1], self.end_id, label=label, weight=1)
            path.append(self.end_id)

    # merge the divergent region between each consecutive consensus pair
    for i in tqdm(range(len(self.consensus_node_ids) - 1)):
        if verbose:
            print(self.consensus_node_ids[i], self.consensus_node_ids[i + 1])
        self.merge_paths_between(
            self.consensus_node_ids[i],
            self.consensus_node_ids[i + 1],
            path_sim_type=path_sim_type,
            verbose=verbose,
            **kwargs,
        )
455
+
456
def get_variable_node_ids(self):
    """Return IDs of all nodes that are not on the consensus spine."""
    variable_ids = []
    for node in self.nodedict.values():
        if node.ID not in self.consensus_node_ids:
            variable_ids.append(node.ID)
    return variable_ids
460
+
461
def compress_paths_between(self, start_node_id: int, end_node_id: int):
    """Placeholder: compress the paths between two nodes (not implemented)."""
    pass
463
+
464
def compress_graph(self):
    """Placeholder: compress the whole graph (not implemented)."""
    pass
466
+
467
def update_influence_scores(self, outcome: Dict[int, float], discount_factor: float = 0.2):
    """Score spine nodes by how much their outgoing branching correlates
    with per-sequence outcomes, then propagate influence backward.

    outcome         -- sequence label -> outcome value.
    discount_factor -- weight of the next spine node's total when
                       propagating influence backward.
    Returns (node_id, score) pairs sorted by score, descending.

    NOTE(review): the backward pass assumes consecutive entries of
    direct_scores are adjacent on the spine (nodedict iteration order
    after toposort) — confirm that assumption holds.
    """
    self.toposort()
    direct_scores = []
    for node in self.nodedict.values():
        next_out_weight = sum(e.weight for e in node.outEdges.values())
        next_in_weight = sum(e.weight for e in node.inEdges.values())
        # only spine nodes (traversed by every sequence) are scored
        if next_out_weight == self.num_sequences and next_in_weight == self.num_sequences:
            out_list = []
            for edge in node.outEdges.values():
                # repeat each edge's mean outcome once per distinct label,
                # so heavier branches weigh more in the variance
                for _ in range(len(set(edge.labels))):
                    out_list.append(np.mean([outcome[label] for label in set(edge.labels)]))
            direct_scores.append((node.ID, np.var(out_list)))

    scores = direct_scores.copy()

    # Start from the end and propagate influence backward
    for i in range(len(scores) - 2, -1, -1):
        # Current node gets its direct influence plus discounted influence of next node
        current_direct = scores[i][1]
        next_total = scores[i + 1][1]
        scores[i] = (scores[i][0], current_direct + discount_factor * next_total)

    scores.sort(key=lambda x: x[1], reverse=True)
    return scores
491
+
492
def jsOutput(
    self,
    verbose: bool = False,
    annotate_consensus: bool = True,
    color_annotations: Dict[int, str] = None,
):
    """Return JavaScript source lines declaring `nodes` and `edges` for
    vis.js (http://visjs.org), with per-node tooltips listing the
    sequences and alternate phrasings stored on each TextNode.

    color_annotations -- optional node.ID -> CSS colour overrides.
    """

    # NOTE(review): pathdict is never populated and `count` is never
    # incremented, so the consensus-annotation branch below is currently
    # unreachable and `path` is unused. If the branch is revived, note it
    # would append ", x: ..." right after the trailing comma emitted
    # above, producing ",," in the JS object literal, and that it emits
    # `is_consensus` while htmlOutput()'s footer reads `node.isConsensus`
    # — confirm intended behaviour before enabling.
    pathdict = {}
    if annotate_consensus:
        path, __, __ = self.consensus()
    lines = ["var nodes = ["]

    ni = self.nodeiterator()
    count = 0
    for node in ni():
        # tooltip: which sequences visit this node plus their phrasings
        title_text = ""
        if node.sequences:
            title_text += f"Sequences: {node.sequences}"
        if node.variations:
            title_text += ";;;".join(
                [f"{sequence_id}: {text}" for sequence_id, text in node.variations.items()]
            )
        title_text = title_text.replace('"', "'")  # keep the JS string literal valid
        line = (
            "    {id:"
            + str(node.ID)
            + ', label: "'
            + str(node.ID)
            + ": "
            + node.text.replace('"', "'")
            + '", title: '
            + '"'
            + title_text
            + '",'
        )
        if color_annotations and node.ID in color_annotations:
            line += f" color: '{color_annotations[node.ID]}', "
        if node.ID in pathdict and count % 5 == 0 and annotate_consensus:
            line += (
                ", x: "
                + str(pathdict[node.ID])
                + ", y: 0 , fixed: { x:true, y:false},"
                + "color: '#7BE141', is_consensus:true},"
            )
        else:
            line += "},"
        lines.append(line)

    lines[-1] = lines[-1][:-1]  # drop the trailing comma on the last node
    lines.append("];")

    lines.append(" ")

    lines.append("var edges = [ ")
    ni = self.nodeiterator()
    for node in ni():
        nodeID = str(node.ID)
        for edge in node.outEdges:
            target = str(edge)
            # edge thickness scales with edge weight
            weight = str(node.outEdges[edge].weight + 1.5)
            lines.append(
                "    {from: "
                + nodeID
                + ", to: "
                + target
                + ", value: "
                + weight
                + ", color: '#4b72b0', arrows: 'to'},"
            )
        if verbose:
            for alignededge in node.alignedTo:
                # These edges indicate alignment to different bases, and are
                # undirected; thus make sure we only plot them once:
                if node.ID > alignededge:
                    continue
                target = str(alignededge)
                lines.append(
                    "    {from: "
                    + nodeID
                    + ", to: "
                    + target
                    + ', value: 1, style: "dash-line", color: "red"},'
                )

    lines[-1] = lines[-1][:-1]
    lines.append("];")
    return lines
582
+
583
def htmlOutput(
    self,
    outfile,
    verbose: bool = False,
    annotate_consensus: bool = True,
    color_annotations: Dict[int, str] = None,
):
    """Write a self-contained HTML page rendering the graph with vis.js.

    outfile -- an open, writable text file handle.
    verbose / annotate_consensus / color_annotations are forwarded to
    jsOutput().
    """
    # header[1:] drops the leading newline before dedenting
    header = """
        <!doctype html>
        <html>
        <head>
          <title>POA Graph Alignment</title>

          <script type="text/javascript" src="https://unpkg.com/vis-network@9.0.4/standalone/umd/vis-network.min.js"></script>
        </head>

        <body>

        <div id="loadingProgress">0%</div>

        <div id="mynetwork"></div>

        <script type="text/javascript">
          // create a network
        """
    outfile.write(textwrap.dedent(header[1:]))
    lines = self.jsOutput(
        verbose=verbose,
        annotate_consensus=annotate_consensus,
        color_annotations=color_annotations,
    )
    for line in lines:
        outfile.write(line + "\n")
    footer = """
        var container = document.getElementById('mynetwork');
        var data= {
            nodes: nodes,
            edges: edges,
        };
        var options = {
            width: '100%',
            height: '800px',
            physics: {
                enabled: false,
                stabilization: {
                    updateInterval: 10,
                },
            },
            edges: {
                color: {
                    inherit: false
                }
            },
            layout: {
                hierarchical: {
                    direction: "UD",
                    sortMethod: "directed",
                    shakeTowards: "roots",
                    levelSeparation: 150, // Adjust as needed
                    nodeSpacing: 800, // Adjust as needed
                    treeSpacing: 200, // Adjust as needed
                    parentCentralization: true,
                }
            }
        };
        var network = new vis.Network(container, data, options);

        network.on('beforeDrawing', function(ctx) {
            nodes.forEach(function(node) {
                if (node.isConsensus) {
                    // Set the level of spine nodes to the bottom
                    network.body.data.nodes.update({
                        id: node.id,
                        level: 0 // Set level to 0 for spine nodes
                    });
                }
            });
        });

        network.on("stabilizationProgress", function (params) {
            document.getElementById("loadingProgress").innerText = Math.round(params.iterations / params.total * 100) + "%";
        });
        network.once("stabilizationIterationsDone", function () {
            document.getElementById("loadingProgress").innerText = "100%";
            setTimeout(function () {
                document.getElementById("loadingProgress").style.display = "none";
            }, 500);
        });

        </script>

        </body>
        </html>
        """
    outfile.write(textwrap.dedent(footer))
679
+
680
+
681
def multi_consensus_response(self, abstention_threshold: Optional[float] = None, filter: bool = True):
    """Return the text of the heaviest path through the merged graph.

    abstention_threshold -- when given, variable nodes (those with
        variations) are kept only if at least that fraction of sequences
        supports them; otherwise they are skipped.
    filter -- when True, ties for the best out-edge are collected (only
        the first is followed).

    NOTE(review): nextInPath is sized by self.end_id; this assumes
    end_id is at least as large as every index used below — confirm
    end_id >= len(self.nodeidlist) or an IndexError is possible.
    """
    self.toposort()
    nodesInReverse = self.nodeidlist[::-1]
    maxnodeID = self.end_id
    nextInPath = [-1] * maxnodeID
    scores = np.zeros(len(self.nodeidlist))

    # node IDs are not contiguous after merging, so translate both ways
    id_to_index = {node_id: index for index, node_id in enumerate(self.nodeidlist)}
    index_to_id = {index: node_id for index, node_id in enumerate(self.nodeidlist)}

    for nodeID in nodesInReverse:
        # best out-edges by (weight, downstream score, neighbour id)
        bestWeightScoreEdges = [(-1, -1, None)]
        for neighbourID in self.nodedict[nodeID].outEdges:
            e = self.nodedict[nodeID].outEdges[neighbourID]
            weightScoreEdge = (e.weight, scores[id_to_index[neighbourID]], neighbourID)

            if weightScoreEdge > bestWeightScoreEdges[0]:
                bestWeightScoreEdges = [weightScoreEdge]
            elif weightScoreEdge == bestWeightScoreEdges[0] and filter:
                bestWeightScoreEdges.append(weightScoreEdge)

        scores[id_to_index[nodeID]] = sum(bestWeightScoreEdges[0][0:2])
        if bestWeightScoreEdges[0][2] is not None:
            nextInPath[id_to_index[nodeID]] = id_to_index[bestWeightScoreEdges[0][2]]
        else:
            nextInPath[id_to_index[nodeID]] = None

    # walk forward from the highest-scoring position
    pos = np.argmax(scores)
    path = []
    text = []
    labels = []

    while pos is not None and pos > -1:
        if abstention_threshold is not None and self.nodedict[index_to_id[pos]].variations:
            # variable node: keep it only if enough sequences support it
            if (
                len(self.nodedict[index_to_id[pos]].labels) / self.num_sequences
                >= abstention_threshold
            ):
                path.append(index_to_id[pos])
                labels.append(self.nodedict[index_to_id[pos]].labels)
                text.append(self.nodedict[index_to_id[pos]].text)
        else:
            path.append(index_to_id[pos])
            labels.append(self.nodedict[index_to_id[pos]].labels)
            text.append(self.nodedict[index_to_id[pos]].text)
        pos = nextInPath[pos]

    # ignore END node
    path = path[:-1]
    # ignore END node
    text = text[:-1]
    # ignore START in text
    text[0] = text[0].replace("START", "")
    labels = labels[:-1]

    return " ".join(text)
740
+
741
+
742
def consensus_response(
    self, selection_threshold: Optional[float] = 0.5, api: str = "openai" , model: str = "gpt-4o-mini", task: str = "bio", **kwargs
) -> str:
    """Build a consensus answer from the spine plus well-supported
    variable nodes, then have an LLM clean up the concatenated text.

    selection_threshold -- minimum fraction of sequences that must support
        a variable successor for it to be included.
        NOTE(review): despite the Optional annotation, passing None makes
        the comparison below raise TypeError — confirm intent.
    api / model / task / kwargs are forwarded to clean_up_text().
    """
    self.toposort()

    consensus_node_ids = self.consensus_node_ids
    print(consensus_node_ids)  # NOTE(review): debug print left in

    selected_node_ids = []

    for node_id in consensus_node_ids:
        # dummy START/END nodes carry no content
        if node_id == self.start_id or node_id == self.end_id:
            continue

        selected_node_ids.append(node_id)

        # pull in variable successors supported by enough sequences
        for neighbor_id in self.nodedict[node_id].outEdges:
            if neighbor_id in consensus_node_ids:
                continue

            if (
                len(self.nodedict[neighbor_id].labels) / self.num_sequences
                >= selection_threshold
            ):
                selected_node_ids.append(neighbor_id)

    text = " ".join([self.nodedict[node_id].text for node_id in selected_node_ids])
    print(text)  # NOTE(review): debug print left in
    cleaned_text = clean_up_text(text, task=task, api=api, model=model, **kwargs)
    return cleaned_text
772
+
773
def save_to_pickle(self, filename):
    """Serialize the whole graph object to *filename* using pickle."""
    with open(filename, "wb+") as handle:
        pickle.dump(self, handle)
776
+
777
def refine_graph(
    self,
    verbose: bool = False,
    save_intermediate_file: str = None,
    final_merge: bool = True,
    **kwargs,
):
    """Full refinement pipeline: merge consensus segments, merge divergent
    paths between them, then (optionally) re-merge consensus segments.

    save_intermediate_file -- optional path for an HTML snapshot of the
        graph after the first consensus merge.
    Sets self.failed when no consensus spine is found or the final merge
    raises.
    """
    self.merge_consensus_nodes(verbose=verbose)

    if save_intermediate_file:
        with open(save_intermediate_file, "w+") as f:
            self.htmlOutput(f, annotate_consensus=False)

    # without a consensus spine there is nothing to anchor the merging on
    if not self.consensus_node_ids:
        self.failed = True
        return

    else:
        self.merge_divergent_paths(verbose=verbose, **kwargs)

    if final_merge:
        try:
            self.merge_consensus_nodes(verbose=verbose)
        except Exception as e:
            # best-effort: keep the partially refined graph but mark failure
            print(e)
            self.failed = True
src/text_poa_graph_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from huggingface_hub import InferenceClient
4
+ from openai import OpenAI
5
+
6
+
7
# System-style instructions for an LLM judge that decides whether two text
# segments are semantically equivalent (used by path_sim_llm, domain="text").
TEXT_SIMILARITY_JUDGE_PROMPT = """
You are given two pieces of text. Your task is to determine whether they are semantically equivalent based solely on their factual content.

Here are the specific guidelines:
- Texts are equivalent if they convey the same core information or concept, regardless of wording or structure
- If one text has information that is a subset of the other text, then the texts are equivalent
- Focus ONLY on the essential claims, not on:
  * Stylistic differences or tone
  * Level of detail (if the core facts remain the same)
  * Connotative differences between words
  * Implied significance or emphasis
  * Presentation order (if all key information is present in both)
- Minor additions of non-contradictory information should not make texts non-equivalent
- For ambiguous cases, prioritize the central claim or purpose of the text

Examples of equivalent pairs:
- "The meeting starts at 3pm" and "The 3 o'clock meeting will begin on time"
- "Research indicates a 15% increase" and "Studies show a fifteen percent rise"
- "was influential in the field" and "had a significant impact on the community"

Examples of non-equivalent pairs:
- "The project might be completed by Friday" and "The project will be finished by Friday"
- "Most experts agree on the approach" and "All experts support the approach"

Strictly follow these guidelines and return ONLY:
- equivalent
- not equivalent
"""
# Variant of the judge prompt specialized for mathematical solution segments
# (used by path_sim_llm, domain="math").
MATH_SIMILARITY_JUDGE_PROMPT = """
You are given two pieces of text from mathematical solutions. Your task is to determine whether the two solution segments are mathematically equivalent in their content, while allowing for stylistic variations.

Here are some important guidelines:
- Solutions should be considered equivalent if:
  1. They communicate the same mathematical content/approach, even if word choice or phrasing differs
  2. They contain the same key mathematical ideas, even if expressed differently
  3. The same mathematical steps are described, even if using different words
  4. They present the same final answer, regardless of wording style or formatting

- Allow for these variations while still considering solutions equivalent:
  1. Stylistic differences ("we will" vs. "we'll" or "I'll")
  2. Different levels of formality in the explanation
  3. Minor rephrasing that preserves the core mathematical content
  4. Use of synonyms or alternative mathematical terminology for the same concept

- Solutions are NOT equivalent if:
  1. They use fundamentally different mathematical approaches
  2. They work with different formulas or equations
  3. They present different mathematical steps or operations
  4. They reach different conclusions or answers
  5. One contains substantial mathematical content that the other lacks

- When examining final answers, focus on mathematical equivalence rather than stylistic presentation
- For solution steps, maintain the core mathematical approach while allowing for rephrasing

Examples of solutions that SHOULD be considered equivalent:
- "We will systematically evaluate each possible grouping" and "We'll evaluate each grouping"
- "The answer is x = 5" and "Therefore, x equals 5"
- "Using the quadratic formula" and "Applying the quadratic formula"

Strictly follow the guidelines above.
Return your judgment in the following format. Do not include any other text:
- equivalent
- not equivalent
"""
71
+
72
def path_sim_llm(
    path1_text: str,
    path2_text: str,
    api: str = "openai",
    model: str = "gpt-4.1-mini",
    verbose: bool = False,
    domain: Optional[str] = "text",
    custom_similarity_judge_prompt: Optional[str] = None,
):
    """Ask an LLM judge whether two path texts are semantically equivalent.

    Args:
        path1_text: first text segment to compare.
        path2_text: second text segment to compare.
        api: LLM backend, "openai" or "hf".
        model: model name for the chat completion call.
        verbose: print the texts and the verdict (and invalid verdicts).
        domain: "text" or "math" to pick a built-in judge prompt; pass
            None/empty together with ``custom_similarity_judge_prompt``
            to use a caller-supplied prompt instead.
        custom_similarity_judge_prompt: judge instructions used when
            ``domain`` is falsy.

    Returns:
        A tuple ``(score, prompt_tokens, completion_tokens)`` where score
        is 1 for "equivalent" and 0 otherwise (including unrecognized
        verdicts, which are treated conservatively as not equivalent).

    Raises:
        ValueError: if ``api`` is unsupported, or ``domain`` is invalid
            with no custom prompt supplied.
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    else:
        raise ValueError(f"Invalid API: {api}")

    # Select the judge instructions; the two-text suffix is identical for
    # every branch, so build it once instead of repeating it per branch.
    if domain == "text":
        judge_instructions = TEXT_SIMILARITY_JUDGE_PROMPT
    elif domain == "math":
        judge_instructions = MATH_SIMILARITY_JUDGE_PROMPT
    elif not domain and custom_similarity_judge_prompt:
        judge_instructions = custom_similarity_judge_prompt
    else:
        raise ValueError(f"Invalid domain: {domain} and no custom similarity judge prompt provided")

    similarity_judge_prompt = (
        f"{judge_instructions}\n\nText 1: {path1_text}\nText 2: {path2_text}"
    )

    # temperature=0 keeps the judgement as deterministic as the API allows.
    completion = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": similarity_judge_prompt},
        ],
    )

    # Normalize the verdict: keep only letters and spaces, then strip.
    judgement = completion.choices[0].message.content.strip()
    judgement = "".join(c for c in judgement if c.isalpha() or c == " ")
    judgement = judgement.strip()

    if verbose:
        print(f"{path1_text} \nand \n{path2_text} \nare {judgement}")

    usage = completion.usage
    if judgement == "equivalent":
        return 1, usage.prompt_tokens, usage.completion_tokens
    elif judgement == "not equivalent":
        return 0, usage.prompt_tokens, usage.completion_tokens
    else:
        if verbose:
            print(f"Invalid judgement: {judgement}")
        return 0, usage.prompt_tokens, usage.completion_tokens
src/utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from huggingface_hub import InferenceClient
4
+ from openai import OpenAI
5
+
6
+
7
def detect_abstain(text: str, api: str, model: str) -> str:
    """Ask an LLM whether *text* abstains from providing biography facts.

    Args:
        text: candidate biography fragment to classify.
        api: LLM backend selector, "openai" or "hf".
        model: model name for the chat completion call.

    Returns:
        The model's raw (stripped) verdict — per the prompt this is
        expected to be "Abstain" or "Not abstain", but the output is not
        validated here.

    Raises:
        ValueError: if *api* is not a supported backend.
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    else:
        raise ValueError(f"Invalid API: {api}")

    detect_abstain_prompt = f"""
    You are given a piece of text that is a part of a biography of an entity.
    Text: {text}

    If the text claims a lack of knowledge about the topic, return "Abstain".
    Otherwise, return "Not abstain".
    """

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": detect_abstain_prompt},
        ],
    )

    return completion.choices[0].message.content.strip()
32
+
33
+
34
def calculate_factf1_at_k(
    supported_facts: List[str], unsupported_facts: List[str], k: int
) -> float:
    """Compute a FactScore-style F1 score at *k*.

    Precision is the fraction of generated facts that are supported;
    recall is the number of supported facts relative to the target count
    ``k``, capped at 1.

    Args:
        supported_facts: facts verified as supported.
        unsupported_facts: facts verified as unsupported.
        k: target number of facts for full recall (must be positive;
            k == 0 raises ZeroDivisionError, as before).

    Returns:
        The harmonic mean of precision and recall in [0.0, 1.0];
        0.0 when there are no supported facts.
    """
    # No supported facts means precision (and hence F1) is zero.
    # Return a float for a consistent return type.
    if not supported_facts:
        return 0.0

    precision = len(supported_facts) / (len(supported_facts) + len(unsupported_facts))
    recall = min(len(supported_facts) / k, 1)
    return 2 * precision * recall / (precision + recall)