File size: 12,521 Bytes
d2ff6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
from typing import Optional, Tuple, Union

from src.generation_utils import (
    extract_alternative_paths,
    extract_context,
    extract_equivalent_classes,
    self_complete,
    verify_correctness_pairwise,
)
from src.global_edit_utils import clean_up_text
from src.text_poa_graph import TextPOAGraph

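"""
Documentation for decode_consensus:

Decodes from a TextPOAGraph object to a string by sequentially selecting nodes based on the selection threshold.
For selected nodes that have variations, the longest text among the node's own text and its variations is used.
Text is edited using clean_up_text (e.g. to clean up text by removing incoherencies, disfluencies, and redundancies).

Args:
    text_poa_graph: The TextPOAGraph object to decode.
    selection_threshold: The threshold for selecting nodes.
    task: The task name passed to clean_up_text.
    verbose: If True, return both the raw decoded text and the edited text.
    **kwargs: Extra keyword arguments (e.g. the model to use) forwarded to clean_up_text.

Returns:
    A string of the decoded text (or a (raw text, edited text) tuple when verbose is True).
"""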


def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: Optional[float] = 0.5,
    task: str = "bio",
    verbose: bool = False,
    **kwargs,
) -> Union[str, Tuple[str, str]]:
    """Decode a consensus text from a TextPOAGraph and clean it up.

    After calling ``text_poa_graph.toposort()``, the consensus node ids are
    walked in order (start and end nodes are skipped). Every consensus node is
    selected; each of its out-neighbours that is not itself a consensus node is
    selected as well when the fraction ``len(labels) / num_sequences`` of that
    neighbour is at least ``selection_threshold``. For a selected node with
    variations, the longest string among its variations and its own ``text``
    is used. The pieces are joined with single spaces and passed to
    ``clean_up_text``.

    Args:
        text_poa_graph: The graph to decode.
        selection_threshold: Minimum support fraction for an off-consensus
            neighbour to be included. ``None`` falls back to the default 0.5.
        task: Task name forwarded to ``clean_up_text``.
        verbose: When True, return both the raw and the edited text.
        **kwargs: Forwarded unchanged to ``clean_up_text``.

    Returns:
        ``"Abstain"`` if ``text_poa_graph.failed`` is truthy (regardless of
        ``verbose``). Otherwise the edited text, or the tuple
        ``(raw_text, edited_text)`` when ``verbose`` is True.
    """
    if text_poa_graph.failed:
        return "Abstain"

    # The annotation admits None; treat it as "use the default" rather than
    # failing on the `>=` comparison below.
    if selection_threshold is None:
        selection_threshold = 0.5

    text_poa_graph.toposort()

    nodedict = text_poa_graph.nodedict
    consensus_node_ids = text_poa_graph.consensus_node_ids
    # Set lookup avoids rescanning the consensus list for every out-edge.
    consensus_lookup = set(consensus_node_ids)
    boundary_ids = (text_poa_graph.start_id, text_poa_graph.end_id)
    num_sequences = text_poa_graph.num_sequences

    selected_node_ids = []
    already_selected = set()
    for node_id in consensus_node_ids:
        if node_id in boundary_ids:
            continue

        selected_node_ids.append(node_id)
        already_selected.add(node_id)

        for neighbor_id in nodedict[node_id].outEdges:
            if neighbor_id in consensus_lookup:
                continue
            # A neighbour reachable from several consensus nodes must only be
            # emitted once, otherwise its text is duplicated in the output.
            if neighbor_id in already_selected:
                continue

            support = len(nodedict[neighbor_id].labels) / num_sequences
            if support >= selection_threshold:
                selected_node_ids.append(neighbor_id)
                already_selected.add(neighbor_id)

    texts = []
    for node_id in selected_node_ids:
        node = nodedict[node_id]
        if node.variations:
            # Pick the longest wording among the variations and the node text.
            candidates = list(node.variations.values())
            candidates.append(node.text)
            texts.append(max(candidates, key=len))
        else:
            texts.append(node.text)

    text = " ".join(texts)
    edited_text = clean_up_text(text=text, task=task, api="openai", **kwargs)
    if verbose:
        return text, edited_text
    return edited_text


def _mark_possible_error(masked_text, node_id):
    """Swap the separators tagged with ``node_id`` for possible-error markers."""
    return masked_text.replace(
        f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*"
    ).replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")


def _build_masked_candidates(text_poa_graph, labels, high_uncertainty_nodes):
    """Render every sequence path as text, wrapping the node after each
    high-uncertainty node in separators tagged with that high-uncertainty id."""
    uncertain_lookup = set(high_uncertainty_nodes)
    masked_candidates = {}
    for label in labels:
        pieces = []
        # Id of the high-uncertainty node that immediately precedes the current
        # node, if any. Reset per label so state never leaks between sequences.
        anchor_id = None
        for node_id in text_poa_graph._seq_paths[label]:
            if anchor_id is not None:
                pieces.append(f" *START_SEPARATOR*_{anchor_id} ")

            node = text_poa_graph.nodedict[node_id]
            variations = node.variations
            if len(variations) > 0:
                # Fall back to the node's own text when this label has no
                # dedicated variation instead of raising KeyError.
                pieces.append(variations.get(label, node.text))
            else:
                pieces.append(node.text)
            pieces.append(" ")

            if anchor_id is not None:
                pieces.append(f" *END_SEPARATOR*_{anchor_id} ")

            anchor_id = node_id if node_id in uncertain_lookup else None
        masked_candidates[label] = "".join(pieces)
    return masked_candidates


def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Prune candidate sequences by pairwise verification, then self-complete.

    Steps:
      1. Consensus nodes (excluding start/end) whose
         ``len(outEdges) / num_sequences`` exceeds ``uncertainty_threshold``
         are treated as high-uncertainty nodes.
      2. Each sequence in ``text_poa_graph._seq_paths`` is rendered as text;
         the node following a high-uncertainty node is wrapped in
         ``*START_SEPARATOR*_<id>`` / ``*END_SEPARATOR*_<id>`` where ``<id>``
         is the high-uncertainty node id, so that step 3 can locate it.
      3. For every high-uncertainty node, consecutive pairs of equivalence
         classes (from ``extract_equivalent_classes``) are compared with
         ``verify_correctness_pairwise``. A representative scoring below 1.0
         has its region re-marked as a possible error. Its whole class is
         pruned if ``grace_period`` is False, or if the representative was
         already flagged at its previous verified step; a score of exactly
         1.0 clears the flag. An unpaired trailing class is not verified.
      4. If every remaining label gets pruned at some node, or there is only
         one candidate overall, the "patch" prompt (all approaches) is used;
         otherwise the default prompt (surviving approaches) is used.

    Args:
        text_poa_graph: Graph holding the aligned candidate sequences.
        problem: Problem statement inserted into the prompts and passed to the
            verifier.
        uncertainty_threshold: Branching-factor threshold for step 1.
        verification_api: API name forwarded to the verifier and to
            ``self_complete``.
        verification_model: Model name forwarded to the verifier and to
            ``self_complete``.
        grace_period: When True, a label's first incorrect step only flags it;
            pruning happens on a subsequent consecutive incorrect step.

    Returns:
        A tuple ``(completion, masked_candidates)`` where ``completion`` is
        whatever ``self_complete`` returns and ``masked_candidates`` maps each
        sequence label to its marked-up text.
    """
    boundary_ids = (text_poa_graph.start_id, text_poa_graph.end_id)
    high_uncertainty_nodes = []
    for node_id in text_poa_graph.consensus_node_ids:
        if node_id in boundary_ids:
            continue

        outgoing_edges = text_poa_graph.nodedict[node_id].outEdges
        branching_factor = len(outgoing_edges) / text_poa_graph.num_sequences

        if branching_factor > uncertainty_threshold:
            high_uncertainty_nodes.append(node_id)

    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = _build_masked_candidates(
        text_poa_graph, selected_labels, high_uncertainty_nodes
    )

    patch_start_node = None

    # Tracks, per label, whether its previous verified step was judged
    # incorrect; used to give a grace period for the first incorrect step.
    prev_step = {label: None for label in selected_labels}

    for node_id in high_uncertainty_nodes:
        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
        new_labels = selected_labels.copy()

        # Only do self-verification for labels from different semantically equivalent branches
        if len(equivalent_classes) <= 1:
            continue

        # Compare classes in consecutive, non-overlapping pairs: (0, 1), (2, 3), ...
        for i in range(0, len(equivalent_classes) - 1, 2):
            label_a = equivalent_classes[i][0]
            label_b = equivalent_classes[i + 1][0]
            full_a = context_before[label_a] + alternative_paths[label_a]
            full_b = context_before[label_b] + alternative_paths[label_b]

            score = verify_correctness_pairwise(
                full_text_1=full_a,
                full_text_2=full_b,
                verification_model=verification_model,
                problem=problem,
                api=verification_api,
            )

            for class_idx, label, raw_score in (
                (i, label_a, score[0]),
                (i + 1, label_b, score[1]),
            ):
                value = float(raw_score)
                if value < 1.0:
                    print(f"Label {label} is incorrect at node {node_id}")
                    masked_candidates[label] = _mark_possible_error(
                        masked_candidates[label], node_id
                    )
                    # Read the flag BEFORE updating it; otherwise the grace
                    # period can never take effect.
                    previously_flagged = bool(prev_step[label])
                    prev_step[label] = True
                    if previously_flagged or not grace_period:
                        for label_i in equivalent_classes[class_idx]:
                            if label_i in new_labels:
                                new_labels.remove(label_i)
                                print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
                elif value == 1.0:
                    prev_step[label] = False

        if len(new_labels) == 0:
            # Everything was pruned here: keep the previous selection and
            # switch to the patch prompt below.
            patch_start_node = node_id
            break

        selected_labels = new_labels.copy()

    # These are the pruned approaches with masking
    print(masked_candidates)
    masked_lines = []
    for label in selected_labels:
        relabelled = (
            masked_candidates[label]
            .replace("START_SEPARATOR", "START_UNCERTAIN_REGION")
            .replace("END_SEPARATOR", "END_UNCERTAIN_REGION")
        )
        masked_lines.append(f"Approach {label}: {relabelled}")
    masked_approaches = "\n".join(masked_lines)
    # These are all approaches with masking
    all_approaches = "\n".join(
        [f"Approach {label}: {masked_text}" for label, masked_text in masked_candidates.items()]
    )

    default_prompt = f"""
    Solve the following math problem with mathematical precision and clarity.

    Problem: {problem}

    Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*). 
    These sections may contain conceptual or computational errors.

    There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
    A verification step indicated that these steps are highly likely to contain errors.

    Potential Approaches:
    {masked_approaches}

    Your task:
    1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
       If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
    2. Using the sections with special markers, identify potential errors.
    3. Develop a rigorous, step-by-step solution based on sound mathematical principles
    4. For uncertain regions:
       - Verify each step using algebraic or numerical validation
       - If correct, incorporate these steps with appropriate justification
       - If incorrect, provide clear corrections with mathematical reasoning for your changes
    5. Follow a comparative approach, using the differences between approaches to identify potential errors.
    6. Do not blindly follow the approaches, but rather use them to identify potential errors.

    Guidelines for your solution:
    - Begin with a strategic overview of your chosen approach
    - Present each mathematical step with clear notation and justification
    - Pay special attention to areas that were previously marked uncertain

    Conclude your solution with:
    Therefore, the final answer is: $\\boxed{{answer}}$.

    Solution:
    """

    patch_prompt = f"""
    Solve the following mathematical problem with precision and clarity.

    Problem: {problem}

    You have been provided with several partial solution approaches that attempted to solve this problem. 
    None of these approaches are correct, but may contain valuable insights.
    Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
    A verification step indicated that these steps are likely to contain errors.

    INSTRUCTIONS:
    1. Synthesize a correct solution using insights from the previous approaches
    2. Pay special attention to fixing the problematic areas marked by separators
    3. Develop your solution step-by-step, showing clear mathematical reasoning
    4. Focus especially on mathematical correctness in areas where previous solutions diverged
    5. Present your work in a logical, sequential manner suitable for an advanced reader

    GUIDELINES FOR MATHEMATICAL RIGOR:
    1. MAINTAIN MATHEMATICAL RIGOR
    - Verify that all mathematical operations follow from established principles and definitions
    - Ensure dimensional consistency throughout calculations
    - Check that algebraic manipulations preserve equality and do not introduce errors
   
    2. CONSIDER ALTERNATIVE PERSPECTIVES
    - Even when approaches reach the same conclusion, examine their reasoning independently
    - Look for more elegant or insightful connections that may be missed across all approaches
    - Consider whether fundamental mathematical principles suggest a different path
   
    3. CRITICAL VALIDATION
    - Test conclusions using known mathematical properties and relationships
    - When possible, verify results using alternative methods
    - Be especially cautious when all approaches agree on a result but use similar reasoning
   
    4. USE PRECISION IN CORRECTIONS
    - When correcting uncertain regions, specify exactly what was incorrect and why
    - Provide clear mathematical justification for any changes
    - Ensure corrections align with standard mathematical principles and notations

    Previous Approaches (for reference only):
{all_approaches}

Your Solution:
[Begin with a clear statement of your approach]
[Provide detailed mathematical steps]
[Ensure correct handling of complex mathematical operations]
[Verify your work at key points, especially in previously problematic areas]

Always conclude with:
Therefore, the final answer is: $\\boxed{{answer}}$ 
    """

    if patch_start_node is not None or len(masked_candidates) == 1:
        print("None correct, patching")
        prompt = patch_prompt
    else:
        prompt = default_prompt

    return self_complete(
        verification_prompt=prompt, verification_model=verification_model, api=verification_api
    ), masked_candidates