Spaces:
Build error
Build error
File size: 12,521 Bytes
d2ff6a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
from typing import Optional, Tuple, Union

from src.generation_utils import (
    extract_alternative_paths,
    extract_context,
    extract_equivalent_classes,
    self_complete,
    verify_correctness_pairwise,
)
from src.global_edit_utils import clean_up_text
from src.text_poa_graph import TextPOAGraph
def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: Optional[float] = 0.5,
    task: str = "bio",
    verbose: bool = False,
    **kwargs,
) -> Union[str, Tuple[str, str]]:
    """Decode a TextPOAGraph into a single string along its consensus path.

    The graph is topologically sorted, then its consensus nodes are walked in
    order (start/end nodes skipped). After each consensus node, any
    non-consensus successor whose support ``len(labels) / num_sequences`` is
    at least ``selection_threshold`` is also kept. Each node is kept at most
    once. For a kept node that has variations, the longest string among its
    variations and its own text is used; otherwise its text is used. The
    space-joined result is passed through ``clean_up_text`` (e.g. to remove
    incoherencies, disfluencies and redundancies).

    Args:
        text_poa_graph: The graph to decode; ``toposort()`` is called on it.
        selection_threshold: Minimum support for keeping a non-consensus
            successor. ``None`` is treated as the default ``0.5``.
        task: Forwarded to ``clean_up_text``.
        verbose: When True, also return the text before clean-up.
        **kwargs: Forwarded to ``clean_up_text``. An ``api`` entry here
            overrides the default ``"openai"``.

    Returns:
        ``"Abstain"`` if ``text_poa_graph.failed`` is set. Otherwise the
        cleaned text, or ``(raw_text, cleaned_text)`` when ``verbose`` is True.
    """
    if text_poa_graph.failed:
        return "Abstain"
    if selection_threshold is None:
        # The signature allows None; treat it as "use the default".
        selection_threshold = 0.5

    text_poa_graph.toposort()
    nodes = text_poa_graph.nodedict
    consensus_node_ids = text_poa_graph.consensus_node_ids
    # Set for O(1) membership tests; the list still drives traversal order.
    consensus_set = set(consensus_node_ids)
    boundary_ids = {text_poa_graph.start_id, text_poa_graph.end_id}

    selected_node_ids = []
    seen = set()
    for node_id in consensus_node_ids:
        if node_id in boundary_ids or node_id in seen:
            continue
        selected_node_ids.append(node_id)
        seen.add(node_id)
        for neighbor_id in nodes[node_id].outEdges:
            # A non-consensus node can be the successor of several consensus
            # nodes; keep it only once so its text is not emitted twice.
            if neighbor_id in consensus_set or neighbor_id in seen:
                continue
            support = len(nodes[neighbor_id].labels) / text_poa_graph.num_sequences
            if support >= selection_threshold:
                selected_node_ids.append(neighbor_id)
                seen.add(neighbor_id)

    texts = []
    for node_id in selected_node_ids:
        node = nodes[node_id]
        if node.variations:
            # Longest among the variations and the node's own text; ties keep
            # the first candidate (variations first, then node.text).
            candidates = [*node.variations.values(), node.text]
            texts.append(max(candidates, key=len))
        else:
            texts.append(node.text)

    text = " ".join(texts)
    api = kwargs.pop("api", "openai")
    edited_text = clean_up_text(text=text, task=task, api=api, **kwargs)
    if verbose:
        return text, edited_text
    return edited_text
def _find_high_uncertainty_nodes(text_poa_graph, uncertainty_threshold):
    """Return consensus node ids (start/end excluded) that branch strongly.

    A node qualifies when ``len(outEdges) / num_sequences`` is strictly greater
    than ``uncertainty_threshold``. Consensus order is preserved.
    """
    boundary_ids = (text_poa_graph.start_id, text_poa_graph.end_id)
    num_sequences = text_poa_graph.num_sequences
    return [
        node_id
        for node_id in text_poa_graph.consensus_node_ids
        if node_id not in boundary_ids
        and len(text_poa_graph.nodedict[node_id].outEdges) / num_sequences
        > uncertainty_threshold
    ]


def _build_masked_candidates(text_poa_graph, labels, high_uncertainty_nodes):
    """Render each label's path as text with uncertainty separator markers.

    A node that directly follows a high-uncertainty node in the path is
    preceded by `` *START_SEPARATOR*_<id> ``. If that node is not itself
    high-uncertainty, it is also followed by `` *END_SEPARATOR*_<id> ``, and
    the region closes. ``<id>`` is the id of the node being wrapped.

    NOTE(review): because markers carry the *following* node's id, the
    relabelling in ``_mark_possible_error`` (called with a high-uncertainty
    node's id) only matches a START marker when that node was itself preceded
    by a high-uncertainty node. It never matches an END marker, since END
    markers are only emitted for non-high-uncertainty nodes. Confirm the
    intended marking scheme.
    """
    high_uncertainty = set(high_uncertainty_nodes)
    masked = {}
    for label in labels:
        # Reset per label so a region left open at the end of one path does
        # not leak a spurious START marker into the next label's text.
        uncertain_region = False
        parts = []
        for node_id in text_poa_graph._seq_paths[label]:
            node = text_poa_graph.nodedict[node_id]
            if uncertain_region:
                parts.append(f" *START_SEPARATOR*_{node_id} ")
            if node_id in high_uncertainty:
                uncertain_region = True
            if node.variations:
                # Presumably variations are keyed by sequence label; fall back
                # to the node text when this label has no entry.
                parts.append(node.variations.get(label, node.text))
            else:
                parts.append(node.text)
            parts.append(" ")
            if uncertain_region and node_id not in high_uncertainty:
                parts.append(f" *END_SEPARATOR*_{node_id} ")
                uncertain_region = False
        masked[label] = "".join(parts)
    return masked


def _mark_possible_error(text, node_id):
    """Relabel the separators tagged with ``node_id`` as POSSIBLE_ERROR markers."""
    return text.replace(
        f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*"
    ).replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")


def _default_prompt(problem, masked_approaches):
    """Prompt for solving ``problem`` guided by the surviving marked approaches."""
    return f"""
Solve the following math problem with mathematical precision and clarity.
Problem: {problem}
Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*).
These sections may contain conceptual or computational errors.
There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
A verification step indicated that these steps are highly likely to contain errors.
Potential Approaches:
{masked_approaches}
Your task:
1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
2. Using the sections with special markers, identify potential errors.
3. Develop a rigorous, step-by-step solution based on sound mathematical principles
4. For uncertain regions:
- Verify each step using algebraic or numerical validation
- If correct, incorporate these steps with appropriate justification
- If incorrect, provide clear corrections with mathematical reasoning for your changes
5. Follow a comparative approach, using the differences between approaches to identify potential errors.
6. Do not blindly follow the approaches, but rather use them to identify potential errors.
Guidelines for your solution:
- Begin with a strategic overview of your chosen approach
- Present each mathematical step with clear notation and justification
- Pay special attention to areas that were previously marked uncertain
Conclude your solution with:
Therefore, the final answer is: $\\boxed{{answer}}$.
Solution:
"""


def _patch_prompt(problem, all_approaches):
    """Prompt for synthesizing a fresh solution when no approach survived."""
    return f"""
Solve the following mathematical problem with precision and clarity.
Problem: {problem}
You have been provided with several partial solution approaches that attempted to solve this problem.
None of these approaches are correct, but may contain valuable insights.
Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
A verification step indicated that these steps are likely to contain errors.
INSTRUCTIONS:
1. Synthesize a correct solution using insights from the previous approaches
2. Pay special attention to fixing the problematic areas marked by separators
3. Develop your solution step-by-step, showing clear mathematical reasoning
4. Focus especially on mathematical correctness in areas where previous solutions diverged
5. Present your work in a logical, sequential manner suitable for an advanced reader
GUIDELINES FOR MATHEMATICAL RIGOR:
1. MAINTAIN MATHEMATICAL RIGOR
- Verify that all mathematical operations follow from established principles and definitions
- Ensure dimensional consistency throughout calculations
- Check that algebraic manipulations preserve equality and do not introduce errors
2. CONSIDER ALTERNATIVE PERSPECTIVES
- Even when approaches reach the same conclusion, examine their reasoning independently
- Look for more elegant or insightful connections that may be missed across all approaches
- Consider whether fundamental mathematical principles suggest a different path
3. CRITICAL VALIDATION
- Test conclusions using known mathematical properties and relationships
- When possible, verify results using alternative methods
- Be especially cautious when all approaches agree on a result but use similar reasoning
4. USE PRECISION IN CORRECTIONS
- When correcting uncertain regions, specify exactly what was incorrect and why
- Provide clear mathematical justification for any changes
- Ensure corrections align with standard mathematical principles and notations
Previous Approaches (for reference only):
{all_approaches}
Your Solution:
[Begin with a clear statement of your approach]
[Provide detailed mathematical steps]
[Ensure correct handling of complex mathematical operations]
[Verify your work at key points, especially in previously problematic areas]
Always conclude with:
Therefore, the final answer is: $\\boxed{{answer}}$
"""


def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Produce a solution by self-verifying divergent branches of a TextPOAGraph.

    1. Consensus nodes (start/end excluded) with
       ``len(outEdges) / num_sequences > uncertainty_threshold`` are treated
       as high-uncertainty nodes.
    2. Every sequence label's path is rendered as text with separator markers
       around uncertain regions (see ``_build_masked_candidates``).
    3. At each high-uncertainty node, the still-selected labels are grouped
       with ``extract_equivalent_classes``. Adjacent classes are compared in
       pairs with ``verify_correctness_pairwise``; with an odd class count the
       last class is not verified. A class whose representative (first label)
       scores below 1.0 has its markers relabelled as POSSIBLE_ERROR and is
       pruned. Without ``grace_period`` it is pruned immediately. With it, the
       label is pruned only on a second incorrect verdict not separated by a
       correct one (score == 1.0).
    4. If all labels are pruned at some node, or there is only one candidate,
       a "patch" prompt over all candidates is used. Otherwise a prompt over
       the surviving candidates is used. The prompt goes to ``self_complete``.

    Args:
        text_poa_graph: Graph whose ``_seq_paths`` provide the candidate texts.
        problem: Problem statement inserted into the verification and
            generation prompts.
        uncertainty_threshold: Branching ratio above which a node is verified.
        verification_api: ``api`` argument for verification and completion.
        verification_model: Model name for verification and completion.
        grace_period: Tolerate one isolated incorrect verdict per label.

    Returns:
        ``(self_complete(...) result, masked_candidates)``, where
        ``masked_candidates`` maps each label to its annotated path text.
    """
    high_uncertainty_nodes = _find_high_uncertainty_nodes(
        text_poa_graph, uncertainty_threshold
    )
    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = _build_masked_candidates(
        text_poa_graph, selected_labels, high_uncertainty_nodes
    )

    patch_start_node = None
    # Per-label flag: truthy while the label's latest verdict was "incorrect".
    prev_step = {label: None for label in selected_labels}
    for node_id in high_uncertainty_nodes:
        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        equivalent_classes = extract_equivalent_classes(
            text_poa_graph, node_id, selected_labels
        )
        # Only self-verify labels from different semantically equivalent branches.
        if len(equivalent_classes) <= 1:
            continue

        new_labels = selected_labels.copy()
        # Pairs (0, 1), (2, 3), ...; an unpaired last class is left unverified.
        for i in range(0, len(equivalent_classes) - 1, 2):
            class_a, class_b = equivalent_classes[i], equivalent_classes[i + 1]
            label_a, label_b = class_a[0], class_b[0]
            score = verify_correctness_pairwise(
                full_text_1=context_before[label_a] + alternative_paths[label_a],
                full_text_2=context_before[label_b] + alternative_paths[label_b],
                verification_model=verification_model,
                problem=problem,
                api=verification_api,
            )
            for side, label, members in ((0, label_a, class_a), (1, label_b, class_b)):
                verdict = float(score[side])
                if verdict < 1.0:
                    print(f"Label {label} is incorrect at node {node_id}")
                    masked_candidates[label] = _mark_possible_error(
                        masked_candidates[label], node_id
                    )
                    # Read the flag *before* setting it: with a grace period the
                    # first incorrect verdict only flags the label, and a second
                    # one (with no correct verdict in between) prunes it.
                    already_flagged = bool(prev_step[label])
                    prev_step[label] = True
                    if already_flagged or not grace_period:
                        for member in members:
                            new_labels.remove(member)
                            print(f"\nSequence {member} pruned at node {node_id} (pairwise)")
                elif verdict == 1.0:
                    prev_step[label] = False

        if not new_labels:
            # Every branch was rejected here; fall back to the patch prompt.
            patch_start_node = node_id
            break
        selected_labels = new_labels

    # Debug output of all annotated candidates.
    print(masked_candidates)

    if patch_start_node is not None or len(masked_candidates) == 1:
        print("None correct, patching")
        # All approaches, with their raw separator / POSSIBLE_ERROR markers.
        all_approaches = "\n".join(
            f"Approach {label}: {text}" for label, text in masked_candidates.items()
        )
        prompt = _patch_prompt(problem, all_approaches)
    else:
        # Surviving approaches, with separators renamed to UNCERTAIN_REGION.
        masked_approaches = "\n".join(
            f"Approach {label}: "
            + masked_candidates[label]
            .replace("START_SEPARATOR", "START_UNCERTAIN_REGION")
            .replace("END_SEPARATOR", "END_UNCERTAIN_REGION")
            for label in selected_labels
        )
        prompt = _default_prompt(problem, masked_approaches)

    completion = self_complete(
        verification_prompt=prompt,
        verification_model=verification_model,
        api=verification_api,
    )
    return completion, masked_candidates
|