# congr-visualizer / src/generation_methods.py
# Uploaded by Shahzaib98 ("Upload 11 files"), commit d2ff6a7 (verified).
from typing import Optional, Tuple, Union

from src.generation_utils import (
    extract_alternative_paths,
    extract_context,
    extract_equivalent_classes,
    self_complete,
    verify_correctness_pairwise,
)
from src.global_edit_utils import clean_up_text
from src.text_poa_graph import TextPOAGraph
def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: Optional[float] = 0.5,
    task: str = "bio",
    verbose: bool = False,
    **kwargs,
) -> Union[str, Tuple[str, str]]:
    """Decode a TextPOAGraph into a single string and clean it up.

    Walks the graph's consensus path (skipping the start/end sentinel nodes).
    After each consensus node it also keeps every non-consensus successor that is
    supported by at least ``selection_threshold`` of the sequences. For every kept
    node, the longest text among the node's own text and its variations is used.
    The pieces are joined with spaces and passed to ``clean_up_text`` (e.g. to
    remove incoherencies, disfluencies and redundancies).

    Args:
        text_poa_graph: The graph to decode. ``toposort()`` is called on it in place.
        selection_threshold: Minimum fraction of sequences
            (``len(node.labels) / num_sequences``) a non-consensus successor needs
            in order to be kept. NOTE(review): despite the ``Optional`` hint,
            ``None`` is not handled and raises ``TypeError`` in the comparison.
        task: Forwarded to ``clean_up_text``.
        verbose: If True, also return the raw joined text before editing.
        **kwargs: Forwarded to ``clean_up_text``. An ``api`` entry overrides the
            default editing backend ``"openai"``.

    Returns:
        The edited text, or ``(raw_text, edited_text)`` when ``verbose`` is True.
        If the graph is marked as failed, the plain string ``"Abstain"`` is
        returned, even when ``verbose`` is True.
    """
    if text_poa_graph.failed:
        return "Abstain"

    text_poa_graph.toposort()
    nodes = text_poa_graph.nodedict
    consensus_node_ids = text_poa_graph.consensus_node_ids
    # Set copy so the per-neighbor membership test is O(1) instead of a list scan.
    consensus_lookup = set(consensus_node_ids)
    sentinel_ids = (text_poa_graph.start_id, text_poa_graph.end_id)

    selected_node_ids = []
    for node_id in consensus_node_ids:
        if node_id in sentinel_ids:
            continue
        selected_node_ids.append(node_id)
        for neighbor_id in nodes[node_id].outEdges:
            if neighbor_id in consensus_lookup:
                continue
            support = len(nodes[neighbor_id].labels) / text_poa_graph.num_sequences
            if support >= selection_threshold:
                selected_node_ids.append(neighbor_id)
    # A non-consensus node can be the successor of several consensus nodes. Keep
    # only its first occurrence so its text is not emitted more than once.
    selected_node_ids = list(dict.fromkeys(selected_node_ids))

    texts = []
    for node_id in selected_node_ids:
        node = nodes[node_id]
        if not node.variations:
            texts.append(node.text)
        else:
            # Longest of all variations plus the node's own text. On ties, max()
            # keeps the first candidate, so a variation wins over node.text.
            candidates = list(node.variations.values())
            candidates.append(node.text)
            texts.append(max(candidates, key=len))
    text = " ".join(texts)

    # Let callers choose the editing backend. Passing api= used to raise TypeError.
    api = kwargs.pop("api", "openai")
    edited_text = clean_up_text(text=text, task=task, api=api, **kwargs)
    if verbose:
        return text, edited_text
    return edited_text
def _find_high_uncertainty_nodes(text_poa_graph, uncertainty_threshold) -> list:
    """Return consensus node ids, in consensus order, that branch heavily.

    A node qualifies when ``len(outEdges) / num_sequences`` is strictly greater
    than ``uncertainty_threshold``. The start and end sentinels are never
    included.
    """
    sentinel_ids = (text_poa_graph.start_id, text_poa_graph.end_id)
    return [
        node_id
        for node_id in text_poa_graph.consensus_node_ids
        if node_id not in sentinel_ids
        and len(text_poa_graph.nodedict[node_id].outEdges) / text_poa_graph.num_sequences
        > uncertainty_threshold
    ]


def _build_masked_candidates(text_poa_graph, high_uncertainty_nodes, labels) -> dict:
    """Render each label's path as text, marking the branch after uncertain nodes.

    The node that immediately follows a high-uncertainty node H on a label's path
    is wrapped in `` *START_SEPARATOR*_<H> `` / `` *END_SEPARATOR*_<H> ``. The
    markers carry H's id, which is the id the verification loop later uses to
    relabel them as possible errors. Node text comes from ``variations[label]``
    when the node has variations, otherwise from ``node.text``.
    """
    uncertain = set(high_uncertainty_nodes)
    candidates = {}
    for label in labels:
        parts = []
        # Id of the high-uncertainty node whose branch the current node belongs to.
        # Reset for every label so no marker state leaks between paths.
        open_region = None
        for node_id in text_poa_graph._seq_paths[label]:
            node = text_poa_graph.nodedict[node_id]
            if open_region is not None:
                parts.append(f" *START_SEPARATOR*_{open_region} ")
            if len(node.variations) > 0:
                parts.append(node.variations[label])
            else:
                parts.append(node.text)
            parts.append(" ")
            if open_region is not None:
                parts.append(f" *END_SEPARATOR*_{open_region} ")
            open_region = node_id if node_id in uncertain else None
        candidates[label] = "".join(parts)
    return candidates


def _apply_pairwise_verdict(
    score, label_class, node_id, masked_candidates, flagged, new_labels, grace_period
):
    """Apply one side of a pairwise verification verdict to an equivalence class.

    Score below 1.0:
        The class representative's region at ``node_id`` is relabelled as a
        possible error, and every label of the class is flagged. The class is
        removed from ``new_labels`` either immediately (``grace_period`` False)
        or only if the representative was already flagged by an earlier,
        uncleared incorrect verdict.

    Score of exactly 1.0:
        The class's flags are cleared.

    Mutates ``masked_candidates``, ``flagged`` and ``new_labels`` in place.
    """
    representative = label_class[0]
    value = float(score)
    if value < 1.0:
        print(f"Label {representative} is incorrect at node {node_id}")
        masked_candidates[representative] = (
            masked_candidates[representative]
            .replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*")
            .replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")
        )
        # Read the flag BEFORE setting it. The first incorrect step is the grace step.
        already_flagged = bool(flagged[representative])
        for label in label_class:
            flagged[label] = True
        if already_flagged or not grace_period:
            for label in label_class:
                new_labels.remove(label)
                print(f"\nSequence {label} pruned at node {node_id} (pairwise)")
    elif value == 1.0:
        for label in label_class:
            flagged[label] = False


def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Prune candidate solutions by pairwise verification, then ask for a final solution.

    Steps:
      1. Consensus nodes whose out-degree, relative to the number of sequences,
         exceeds ``uncertainty_threshold`` are treated as high-uncertainty points.
      2. Every label's path is rendered as text, with the branch following each
         such point wrapped in separator markers.
      3. At each high-uncertainty node, the surviving labels are grouped via
         ``extract_equivalent_classes``. If there are at least two classes, the
         representatives of adjacent class pairs (0/1, 2/3, ...) are compared with
         ``verify_correctness_pairwise``. With an odd number of classes, the last
         class is not verified at that node.
      4. An incorrect verdict relabels that branch as a possible error and may
         prune the class (see ``grace_period``).
      5. If every label gets pruned, or only one candidate exists, the "patch"
         prompt (all approaches, none assumed correct) is used. Otherwise the
         default prompt over the surviving approaches is used.

    Args:
        text_poa_graph: Graph whose ``_seq_paths`` hold one path per candidate label.
        problem: Problem statement embedded in the verification and final prompts.
        uncertainty_threshold: Branching ratio above which a consensus node is uncertain.
        verification_api: Backend passed as ``api`` to the verifier and to ``self_complete``.
        verification_model: Model name for verification and final completion.
        grace_period: If True, a class is pruned only on its second incorrect
            verdict without an intervening correct one. If False, it is pruned
            on the first incorrect verdict.

    Returns:
        ``(completion, masked_candidates)``. ``completion`` is whatever
        ``self_complete`` returns. ``masked_candidates`` maps every label to its
        marked-up path text.
    """
    high_uncertainty_nodes = _find_high_uncertainty_nodes(text_poa_graph, uncertainty_threshold)
    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = _build_masked_candidates(
        text_poa_graph, high_uncertainty_nodes, selected_labels
    )

    patch_start_node = None
    # Per-label "was judged incorrect at the previous verified step" flag,
    # used to give a grace period for the first incorrect step.
    flagged = {label: None for label in selected_labels}
    for node_id in high_uncertainty_nodes:
        equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
        # Only self-verify when labels split into different semantically equivalent branches.
        if len(equivalent_classes) <= 1:
            continue
        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        new_labels = selected_labels.copy()
        for i in range(0, len(equivalent_classes) - 1, 2):
            class_a, class_b = equivalent_classes[i], equivalent_classes[i + 1]
            label_a, label_b = class_a[0], class_b[0]
            score = verify_correctness_pairwise(
                full_text_1=context_before[label_a] + alternative_paths[label_a],
                full_text_2=context_before[label_b] + alternative_paths[label_b],
                verification_model=verification_model,
                problem=problem,
                api=verification_api,
            )
            _apply_pairwise_verdict(
                score[0], class_a, node_id, masked_candidates, flagged, new_labels, grace_period
            )
            _apply_pairwise_verdict(
                score[1], class_b, node_id, masked_candidates, flagged, new_labels, grace_period
            )
        if not new_labels:
            # Everything was pruned: fall back to synthesizing from all approaches.
            patch_start_node = node_id
            break
        selected_labels = new_labels

    # These are the pruned approaches with masking
    print(masked_candidates)
    masked_approaches = "\n".join(
        f"Approach {label}: "
        + masked_candidates[label]
        .replace("START_SEPARATOR", "START_UNCERTAIN_REGION")
        .replace("END_SEPARATOR", "END_UNCERTAIN_REGION")
        for label in selected_labels
    )
    # These are all approaches with masking
    all_approaches = "\n".join(
        f"Approach {label}: {candidate}" for label, candidate in masked_candidates.items()
    )

    if patch_start_node is not None or len(masked_candidates) == 1:
        print("None correct, patching")
        prompt = f"""
Solve the following mathematical problem with precision and clarity.
Problem: {problem}
You have been provided with several partial solution approaches that attempted to solve this problem.
None of these approaches are correct, but may contain valuable insights.
Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
A verification step indicated that these steps are likely to contain errors.
INSTRUCTIONS:
1. Synthesize a correct solution using insights from the previous approaches
2. Pay special attention to fixing the problematic areas marked by separators
3. Develop your solution step-by-step, showing clear mathematical reasoning
4. Focus especially on mathematical correctness in areas where previous solutions diverged
5. Present your work in a logical, sequential manner suitable for an advanced reader
GUIDELINES FOR MATHEMATICAL RIGOR:
1. MAINTAIN MATHEMATICAL RIGOR
- Verify that all mathematical operations follow from established principles and definitions
- Ensure dimensional consistency throughout calculations
- Check that algebraic manipulations preserve equality and do not introduce errors
2. CONSIDER ALTERNATIVE PERSPECTIVES
- Even when approaches reach the same conclusion, examine their reasoning independently
- Look for more elegant or insightful connections that may be missed across all approaches
- Consider whether fundamental mathematical principles suggest a different path
3. CRITICAL VALIDATION
- Test conclusions using known mathematical properties and relationships
- When possible, verify results using alternative methods
- Be especially cautious when all approaches agree on a result but use similar reasoning
4. USE PRECISION IN CORRECTIONS
- When correcting uncertain regions, specify exactly what was incorrect and why
- Provide clear mathematical justification for any changes
- Ensure corrections align with standard mathematical principles and notations
Previous Approaches (for reference only):
{all_approaches}
Your Solution:
[Begin with a clear statement of your approach]
[Provide detailed mathematical steps]
[Ensure correct handling of complex mathematical operations]
[Verify your work at key points, especially in previously problematic areas]
Always conclude with:
Therefore, the final answer is: $\\boxed{{answer}}$
"""
    else:
        prompt = f"""
Solve the following math problem with mathematical precision and clarity.
Problem: {problem}
Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*).
These sections may contain conceptual or computational errors.
There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
A verification step indicated that these steps are highly likely to contain errors.
Potential Approaches:
{masked_approaches}
Your task:
1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
2. Using the sections with special markers, identify potential errors.
3. Develop a rigorous, step-by-step solution based on sound mathematical principles
4. For uncertain regions:
- Verify each step using algebraic or numerical validation
- If correct, incorporate these steps with appropriate justification
- If incorrect, provide clear corrections with mathematical reasoning for your changes
5. Follow a comparative approach, using the differences between approaches to identify potential errors.
6. Do not blindly follow the approaches, but rather use them to identify potential errors.
Guidelines for your solution:
- Begin with a strategic overview of your chosen approach
- Present each mathematical step with clear notation and justification
- Pay special attention to areas that were previously marked uncertain
Conclude your solution with:
Therefore, the final answer is: $\\boxed{{answer}}$.
Solution:
"""
    return self_complete(
        verification_prompt=prompt, verification_model=verification_model, api=verification_api
    ), masked_candidates