Spaces:
Build error
Build error
| from typing import Optional | |
| from src.generation_utils import ( | |
| extract_alternative_paths, | |
| extract_context, | |
| extract_equivalent_classes, | |
| self_complete, | |
| verify_correctness_pairwise, | |
| ) | |
| from src.global_edit_utils import clean_up_text | |
| from src.text_poa_graph import TextPOAGraph | |
def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: Optional[float] = 0.5,
    task: str = "bio",
    verbose: bool = False,
    **kwargs,
) -> "str | tuple[str, str]":
    """Decode a TextPOAGraph into a single globally-edited consensus string.

    The graph is topologically sorted, then the consensus path is walked in
    order. Every consensus node except the synthetic start/end nodes is kept.
    In addition, each non-consensus successor of a kept consensus node is kept
    when the fraction of sequences passing through it
    (``len(node.labels) / text_poa_graph.num_sequences``) is at least
    ``selection_threshold``. Each such successor is included at most once.
    For nodes that carry variations, the longest string among the variations
    and the node's own text is used. The joined text is then passed to
    ``clean_up_text`` (with ``api="openai"``) to remove incoherencies,
    disfluencies and redundancies.

    Args:
        text_poa_graph: The graph to decode.
        selection_threshold: Minimum support fraction for including a
            non-consensus successor node. It is compared with ``>=``, so in
            practice it must be a number.
        task: Task name forwarded to ``clean_up_text``.
        verbose: If True, also return the text before global editing.
        **kwargs: Extra keyword arguments forwarded to ``clean_up_text``.

    Returns:
        ``"Abstain"`` if ``text_poa_graph.failed`` is set. Otherwise the edited
        text, or a ``(raw_text, edited_text)`` tuple when ``verbose`` is True.
    """
    if text_poa_graph.failed:
        return "Abstain"

    text_poa_graph.toposort()
    nodedict = text_poa_graph.nodedict
    consensus_node_ids = text_poa_graph.consensus_node_ids
    consensus_set = set(consensus_node_ids)  # O(1) membership in the inner loop
    boundary_ids = {text_poa_graph.start_id, text_poa_graph.end_id}

    selected_node_ids = []
    added_neighbors = set()
    for node_id in consensus_node_ids:
        if node_id in boundary_ids:
            continue
        selected_node_ids.append(node_id)
        for neighbor_id in nodedict[node_id].outEdges:
            # Consensus nodes are emitted by the outer loop. The synthetic
            # start/end nodes are never emitted. A side branch reachable from
            # several consensus nodes must only contribute its text once.
            if (
                neighbor_id in consensus_set
                or neighbor_id in boundary_ids
                or neighbor_id in added_neighbors
            ):
                continue
            support = len(nodedict[neighbor_id].labels) / text_poa_graph.num_sequences
            if support >= selection_threshold:
                selected_node_ids.append(neighbor_id)
                added_neighbors.add(neighbor_id)

    texts = []
    for node_id in selected_node_ids:
        node = nodedict[node_id]
        if node.variations:
            # Select the longest candidate. Variations come first, so ties
            # resolve to the first variation, as max() keeps the first maximum.
            candidates = [*node.variations.values(), node.text]
            texts.append(max(candidates, key=len))
        else:
            texts.append(node.text)

    text = " ".join(texts)
    edited_text = clean_up_text(text=text, task=task, api="openai", **kwargs)
    if verbose:
        return text, edited_text
    return edited_text
def _find_high_uncertainty_nodes(text_poa_graph, uncertainty_threshold):
    """Return consensus node ids whose branching factor exceeds ``uncertainty_threshold``.

    The branching factor is ``len(node.outEdges) / text_poa_graph.num_sequences``.
    The start and end nodes are excluded. Ids are returned in
    ``consensus_node_ids`` order.
    """
    boundary_ids = (text_poa_graph.start_id, text_poa_graph.end_id)
    return [
        node_id
        for node_id in text_poa_graph.consensus_node_ids
        if node_id not in boundary_ids
        and len(text_poa_graph.nodedict[node_id].outEdges) / text_poa_graph.num_sequences
        > uncertainty_threshold
    ]


def _node_text_for_label(node, label):
    """Return the text of ``node`` as written by sequence ``label``."""
    if node.variations:
        # NOTE(review): falls back to node.text when ``label`` has no entry in
        # ``variations``. Previously this indexed directly and raised KeyError.
        # Confirm whether variations is expected to cover every label.
        return node.variations.get(label, node.text)
    return node.text


def _build_masked_candidates(text_poa_graph, labels, high_uncertainty_ids):
    """Render each label's path to text, bracketing each uncertain step with separators.

    The node that immediately follows a high-uncertainty node ``h`` on a path
    is wrapped as ``" *START_SEPARATOR*_{h} " ... " *END_SEPARATOR*_{h} "``.
    The markers carry the id of ``h`` (the branching node being verified), so
    that ``_mark_possible_error(text, h)`` can find them again. Marker state is
    reset for every label, and each START is always paired with an END.
    """
    masked = {}
    for label in labels:
        pieces = []
        region_id = None  # id of the high-uncertainty node whose successor this is
        for node_id in text_poa_graph._seq_paths[label]:
            if region_id is not None:
                pieces.append(f" *START_SEPARATOR*_{region_id} ")
            pieces.append(_node_text_for_label(text_poa_graph.nodedict[node_id], label))
            pieces.append(" ")
            if region_id is not None:
                pieces.append(f" *END_SEPARATOR*_{region_id} ")
            region_id = node_id if node_id in high_uncertainty_ids else None
        masked[label] = "".join(pieces)
    return masked


def _mark_possible_error(masked_text, node_id):
    """Replace ``node_id``'s separator markers with possible-error markers."""
    return masked_text.replace(
        f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*"
    ).replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")


def _build_default_prompt(problem, masked_approaches):
    """Prompt asking the model to solve ``problem`` using the surviving annotated approaches."""
    return f"""
Solve the following math problem with mathematical precision and clarity.
Problem: {problem}
Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*).
These sections may contain conceptual or computational errors.
There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
A verification step indicated that these steps are highly likely to contain errors.
Potential Approaches:
{masked_approaches}
Your task:
1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
2. Using the sections with special markers, identify potential errors.
3. Develop a rigorous, step-by-step solution based on sound mathematical principles
4. For uncertain regions:
- Verify each step using algebraic or numerical validation
- If correct, incorporate these steps with appropriate justification
- If incorrect, provide clear corrections with mathematical reasoning for your changes
5. Follow a comparative approach, using the differences between approaches to identify potential errors.
6. Do not blindly follow the approaches, but rather use them to identify potential errors.
Guidelines for your solution:
- Begin with a strategic overview of your chosen approach
- Present each mathematical step with clear notation and justification
- Pay special attention to areas that were previously marked uncertain
Conclude your solution with:
Therefore, the final answer is: $\\boxed{{answer}}$.
Solution:
"""


def _build_patch_prompt(problem, all_approaches):
    """Prompt asking the model to synthesize a fresh solution when no approach survived."""
    return f"""
Solve the following mathematical problem with precision and clarity.
Problem: {problem}
You have been provided with several partial solution approaches that attempted to solve this problem.
None of these approaches are correct, but may contain valuable insights.
Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
A verification step indicated that these steps are likely to contain errors.
INSTRUCTIONS:
1. Synthesize a correct solution using insights from the previous approaches
2. Pay special attention to fixing the problematic areas marked by separators
3. Develop your solution step-by-step, showing clear mathematical reasoning
4. Focus especially on mathematical correctness in areas where previous solutions diverged
5. Present your work in a logical, sequential manner suitable for an advanced reader
GUIDELINES FOR MATHEMATICAL RIGOR:
1. MAINTAIN MATHEMATICAL RIGOR
- Verify that all mathematical operations follow from established principles and definitions
- Ensure dimensional consistency throughout calculations
- Check that algebraic manipulations preserve equality and do not introduce errors
2. CONSIDER ALTERNATIVE PERSPECTIVES
- Even when approaches reach the same conclusion, examine their reasoning independently
- Look for more elegant or insightful connections that may be missed across all approaches
- Consider whether fundamental mathematical principles suggest a different path
3. CRITICAL VALIDATION
- Test conclusions using known mathematical properties and relationships
- When possible, verify results using alternative methods
- Be especially cautious when all approaches agree on a result but use similar reasoning
4. USE PRECISION IN CORRECTIONS
- When correcting uncertain regions, specify exactly what was incorrect and why
- Provide clear mathematical justification for any changes
- Ensure corrections align with standard mathematical principles and notations
Previous Approaches (for reference only):
{all_approaches}
Your Solution:
[Begin with a clear statement of your approach]
[Provide detailed mathematical steps]
[Ensure correct handling of complex mathematical operations]
[Verify your work at key points, especially in previously problematic areas]
Always conclude with:
Therefore, the final answer is: $\\boxed{{answer}}$
"""


def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Solve ``problem`` by verifying divergent steps of a TextPOAGraph, then prompting a model.

    1. Consensus nodes (start/end excluded) whose out-degree divided by the
       number of sequences exceeds ``uncertainty_threshold`` are treated as
       high-uncertainty branching points.
    2. Each sequence's path is rendered to text. The step following each
       branching point is wrapped in separator markers tagged with that
       point's id.
    3. At each branching point, the still-selected labels are grouped by
       ``extract_equivalent_classes``. The first label of each class in
       consecutive pairs of classes is compared with
       ``verify_correctness_pairwise``. A label scored below 1.0 has its
       region re-marked as a possible error. Its whole class is then pruned:
       immediately if ``grace_period`` is False, otherwise only when that
       label's previous verdict was also incorrect (the first incorrect step
       is forgiven). A verdict of exactly 1.0 clears the flag.
    4. If every label gets pruned, or the graph has a single sequence, a
       "patch" prompt over all approaches is used. Otherwise a prompt over the
       surviving approaches is used.

    Args:
        text_poa_graph: The graph whose sequences are the candidate solutions.
        problem: The problem statement inserted into verification and prompts.
        uncertainty_threshold: Branching-factor threshold for step 1.
        verification_api: API name passed to the verifier and ``self_complete``.
        verification_model: Model name passed to the verifier and ``self_complete``.
        grace_period: Whether the first incorrect verdict for a label is forgiven.

    Returns:
        A ``(completion, masked_candidates)`` tuple. ``completion`` is the
        result of ``self_complete`` on the chosen prompt. ``masked_candidates``
        maps every sequence label to its annotated text.
    """
    high_uncertainty_nodes = _find_high_uncertainty_nodes(text_poa_graph, uncertainty_threshold)
    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = _build_masked_candidates(
        text_poa_graph, selected_labels, set(high_uncertainty_nodes)
    )

    patch_start_node = None
    # True while a label's most recent verdict was "incorrect". This is what
    # gives the grace period for the first incorrect step.
    prev_step = {label: None for label in selected_labels}

    for node_id in high_uncertainty_nodes:
        equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
        # Only do self-verification for labels from different semantically equivalent branches.
        if len(equivalent_classes) <= 1:
            continue
        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        new_labels = selected_labels.copy()

        # Classes are compared in disjoint pairs (0, 1), (2, 3), ... With an
        # odd number of classes, the last one is not verified at this node.
        for i in range(0, len(equivalent_classes) - 1, 2):
            pair = (equivalent_classes[i], equivalent_classes[i + 1])
            label_a, label_b = pair[0][0], pair[1][0]
            # NOTE(review): score is assumed to hold one verdict per text,
            # where 1.0 means "correct".
            score = verify_correctness_pairwise(
                full_text_1=context_before[label_a] + alternative_paths[label_a],
                full_text_2=context_before[label_b] + alternative_paths[label_b],
                verification_model=verification_model,
                problem=problem,
                api=verification_api,
            )
            for side, eq_class in enumerate(pair):
                label = eq_class[0]
                verdict = float(score[side])
                if verdict < 1.0:
                    print(f"Label {label} is incorrect at node {node_id}")
                    masked_candidates[label] = _mark_possible_error(
                        masked_candidates[label], node_id
                    )
                    # Read the flag *before* setting it. Setting it first
                    # made the grace-period branch unreachable.
                    first_strike = not prev_step[label]
                    prev_step[label] = True
                    if not (grace_period and first_strike):
                        for label_i in eq_class:
                            new_labels.remove(label_i)
                            print(f"\nSequence {label_i} pruned at node {node_id} (pairwise)")
                elif verdict == 1.0:
                    prev_step[label] = False

        if not new_labels:
            patch_start_node = node_id
            break
        selected_labels = new_labels

    # These are the pruned approaches with masking
    print(masked_candidates)
    masked_lines = []
    for label in selected_labels:
        annotated = (
            masked_candidates[label]
            .replace("START_SEPARATOR", "START_UNCERTAIN_REGION")
            .replace("END_SEPARATOR", "END_UNCERTAIN_REGION")
        )
        masked_lines.append(f"Approach {label}: {annotated}")
    masked_approaches = "\n".join(masked_lines)
    # These are all approaches with masking
    all_approaches = "\n".join(
        f"Approach {label}: {text}" for label, text in masked_candidates.items()
    )

    if patch_start_node is not None or len(masked_candidates) == 1:
        print("None correct, patching")
        prompt = _build_patch_prompt(problem, all_approaches)
    else:
        prompt = _build_default_prompt(problem, masked_approaches)
    return self_complete(
        verification_prompt=prompt, verification_model=verification_model, api=verification_api
    ), masked_candidates