Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any | |
| from src.model.anchor_tree import make_anchor_tree | |
| from src.model.anchor_tree_domain import detect_tree_domain | |
| from src.model.anchor_tree_types import AnchorTree, AnchorTreeEdge, AnchorTreeNode, AnchorTreeRelation, AnchorTreeRole | |
| from src.model.anchor_types import AnchorRecord | |
| _MATH_STEP_ORDER = { | |
| "integration_by_parts_only": 0, | |
| "select_u_and_dv": 1, | |
| "derive_du_and_v": 2, | |
| "substitute_uv_minus_int_vdu": 3, | |
| "reduce_integral_complexity": 4, | |
| "repeat_if_needed": 5, | |
| "simplify_result": 6, | |
| "integration_constant": 7, | |
| "shortcut_lookup": 50, | |
| "table_reference": 51, | |
| "substitution_switch": 52, | |
| "meta_abort": 53, | |
| "wrong_symbolic_step": 54, | |
| } | |
| _CODE_STEP_ORDER = { | |
| "async_fastapi_service": 0, | |
| "typed_request_models": 1, | |
| "dependency_injection": 2, | |
| "async_handlers": 3, | |
| "validation_path": 4, | |
| "background_jobs": 5, | |
| "deployment_notes": 6, | |
| "django_view_reframe": 50, | |
| "synchronous_handler_reframe": 51, | |
| "template_rendering_branch": 52, | |
| } | |
| _QUANTIFIER_STEP_ORDER = { | |
| "universal_quantifier_scope": 0, | |
| "preserve_universal_claim": 1, | |
| "reject_existential_drift": 2, | |
| "restate_universal_conclusion": 3, | |
| "existential_witness_shift": 50, | |
| "drop_universal_scope": 51, | |
| } | |
| _PROOF_MODE_STEP_ORDER = { | |
| "proof_by_contradiction_mode": 0, | |
| "maintain_negation_assumption": 1, | |
| "derive_contradiction": 2, | |
| "discharge_negation_assumption": 3, | |
| "direct_proof_switch": 50, | |
| "constructive_reset": 51, | |
| } | |
| def _domain_root_label(domain: str | None) -> str: | |
| if domain == "math_ibp": | |
| return "integration_by_parts_only" | |
| if domain == "code_fastapi": | |
| return "async_fastapi_service" | |
| if domain == "quantifier": | |
| return "universal_quantifier_scope" | |
| if domain == "proof_mode": | |
| return "proof_by_contradiction_mode" | |
| return "observed_root" | |
def _step_order(domain: str | None) -> dict[str, int]:
    """Return the label -> sort-rank table for ``domain``.

    Unknown domains (including ``None``) get a fresh empty dict, so every
    label lookup by the caller falls back to the caller's own default.
    """
    tables = {
        "math_ibp": _MATH_STEP_ORDER,
        "code_fastapi": _CODE_STEP_ORDER,
        "quantifier": _QUANTIFIER_STEP_ORDER,
        "proof_mode": _PROOF_MODE_STEP_ORDER,
    }
    return tables.get(domain, {})
| def _classify_math_label(text: str) -> str: | |
| lowered = text.lower() | |
| if "integration by parts" in lowered: | |
| return "integration_by_parts_only" | |
| if ("let u" in lowered or "u =" in lowered) and "dv" in lowered: | |
| return "select_u_and_dv" | |
| if "du" in lowered and ("v =" in lowered or " v " in f" {lowered} "): | |
| return "derive_du_and_v" | |
| if "uv" in lowered or "vdu" in lowered or "substitut" in lowered: | |
| return "substitute_uv_minus_int_vdu" | |
| if "remaining integral" in lowered or "reduce" in lowered: | |
| return "reduce_integral_complexity" | |
| if "repeat" in lowered or "again" in lowered: | |
| return "repeat_if_needed" | |
| if "+ c" in lowered or "+c" in lowered or "constant of integration" in lowered: | |
| return "integration_constant" | |
| if "simplif" in lowered: | |
| return "simplify_result" | |
| if "shortcut" in lowered: | |
| return "shortcut_lookup" | |
| if "table" in lowered: | |
| return "table_reference" | |
| if "substitution" in lowered: | |
| return "substitution_switch" | |
| if any(marker in lowered for marker in ("too hard", "challenging", "no clear path", "alternative approach")): | |
| return "meta_abort" | |
| return "math_observed_step" | |
| def _classify_code_label(text: str) -> str: | |
| lowered = text.lower() | |
| if "fastapi" in lowered and "async" in lowered: | |
| return "async_fastapi_service" | |
| if "pydantic" in lowered or "request model" in lowered or "response model" in lowered: | |
| return "typed_request_models" | |
| if "dependency injection" in lowered or "depends(" in lowered: | |
| return "dependency_injection" | |
| if "async handler" in lowered or ("async def" in lowered and "await" in lowered): | |
| return "async_handlers" | |
| if "validation" in lowered or "validate" in lowered: | |
| return "validation_path" | |
| if "background task" in lowered or "background job" in lowered: | |
| return "background_jobs" | |
| if "deploy" in lowered or "uvicorn" in lowered or "gunicorn" in lowered: | |
| return "deployment_notes" | |
| if "django" in lowered: | |
| return "django_view_reframe" | |
| if "synchronous" in lowered or "sync view" in lowered: | |
| return "synchronous_handler_reframe" | |
| if "template" in lowered or "render" in lowered: | |
| return "template_rendering_branch" | |
| return "code_observed_step" | |
| def _classify_quantifier_label(text: str) -> str: | |
| lowered = text.lower() | |
| if "restore the original universal" in lowered or "restate the universal" in lowered: | |
| return "restate_universal_conclusion" | |
| if "reject" in lowered and ("existential" in lowered or "witness" in lowered): | |
| return "reject_existential_drift" | |
| if "there exists" in lowered or "one witness" in lowered or "existential" in lowered: | |
| return "existential_witness_shift" | |
| if "drop" in lowered and "universal" in lowered: | |
| return "drop_universal_scope" | |
| if "for all" in lowered or "universal" in lowered: | |
| return "preserve_universal_claim" | |
| return "quantifier_observed_step" | |
| def _classify_proof_mode_label(text: str) -> str: | |
| lowered = text.lower() | |
| if "contradiction" in lowered and ("keep" in lowered or "return" in lowered or "mode" in lowered): | |
| return "proof_by_contradiction_mode" | |
| if "assume the negation" in lowered or "assumed negation" in lowered or "negation assumption" in lowered: | |
| return "maintain_negation_assumption" | |
| if "derive a contradiction" in lowered or "contradiction structure" in lowered: | |
| return "derive_contradiction" | |
| if "assumption was false" in lowered or "discharge the assumed negation" in lowered: | |
| return "discharge_negation_assumption" | |
| if "direct proof" in lowered: | |
| return "direct_proof_switch" | |
| if "constructive proof" in lowered or "from scratch" in lowered: | |
| return "constructive_reset" | |
| return "proof_mode_observed_step" | |
def classify_observed_label(domain: str | None, text: str) -> str:
    """Classify ``text`` with the label classifier that belongs to ``domain``.

    Args:
        domain: One of ``"math_ibp"``, ``"code_fastapi"``, ``"quantifier"``,
            ``"proof_mode"``; anything else (including ``None``) is unknown.
        text: The observed step text.

    Returns:
        The label chosen by the domain's classifier, or ``"observed_step"``
        for an unknown domain.
    """
    classifiers = {
        "math_ibp": _classify_math_label,
        "code_fastapi": _classify_code_label,
        "quantifier": _classify_quantifier_label,
        "proof_mode": _classify_proof_mode_label,
    }
    classifier = classifiers.get(domain)
    if classifier is None:
        return "observed_step"
    return classifier(text)
def _role_for_label(label: str, source: str, is_root: bool) -> AnchorTreeRole:
    """Pick the tree role for a node from its label and where it came from.

    Precedence, highest first: root flag -> CONSTRAINT; drift labels ->
    DRIFT; ``"meta_abort"`` -> META; then the source decides
    (``"auxiliary_proposal"`` -> REPAIR, ``"future_hint"`` -> DERIVED);
    everything else is a plain STEP.
    """
    drift_labels = frozenset(
        (
            "shortcut_lookup",
            "table_reference",
            "substitution_switch",
            "wrong_symbolic_step",
            "django_view_reframe",
            "synchronous_handler_reframe",
            "template_rendering_branch",
            "existential_witness_shift",
            "drop_universal_scope",
            "direct_proof_switch",
            "constructive_reset",
        )
    )
    if is_root:
        role = AnchorTreeRole.CONSTRAINT
    elif label in drift_labels:
        role = AnchorTreeRole.DRIFT
    elif label == "meta_abort":
        role = AnchorTreeRole.META
    elif source == "auxiliary_proposal":
        role = AnchorTreeRole.REPAIR
    elif source == "future_hint":
        role = AnchorTreeRole.DERIVED
    else:
        role = AnchorTreeRole.STEP
    return role
| def _score_anchor(anchor: AnchorRecord) -> float: | |
| support = float(anchor.support.detach().item()) if hasattr(anchor.support, "detach") else float(anchor.support) | |
| viability = float(anchor.viability.detach().item()) if hasattr(anchor.viability, "detach") else float(anchor.viability) | |
| return support * max(viability, 1e-6) | |
def _make_root_node(domain: str | None, text: str, root_anchor: AnchorRecord | None) -> AnchorTreeNode:
    """Build the depth-0 root node of an observed tree.

    Args:
        domain: Resolved domain name; selects the root label.
        text: The full input text, stored on the node.
        root_anchor: Anchor backing the root, or ``None``.

    Returns:
        A CONSTRAINT-role node with id ``"root"`` and source ``"prompt"``.
        With an anchor, the id, span, score and a detached clone of ``repr``
        come from it. Without one, the span runs from 0 to the index of the
        last whitespace-separated token of ``text`` (0 for empty text), the
        score is 0.0 and there is no ``repr``.
    """
    if root_anchor is None:
        anchor_id = None
        score = 0.0
        repr_copy = None
        first_idx = 0
        last_idx = max(0, len(text.split()) - 1)
    else:
        anchor_id = int(root_anchor.id)
        score = _score_anchor(root_anchor)
        # Clone so the node does not share storage with the anchor's repr.
        repr_copy = root_anchor.repr.detach().clone()
        first_idx = int(root_anchor.start_idx)
        last_idx = int(root_anchor.end_idx)

    return AnchorTreeNode(
        node_id="root",
        label=_domain_root_label(domain),
        text=text,
        depth=0,
        role=AnchorTreeRole.CONSTRAINT,
        source="prompt",
        anchor_id=anchor_id,
        span_start=first_idx,
        span_end=last_idx,
        repr=repr_copy,
        score=float(score),
    )
def _make_anchor_payloads(active_anchors: list[dict[str, Any] | AnchorRecord]) -> list[dict[str, Any]]:
    """Normalise anchors to payload dicts ordered by ``(start, end)``.

    ``AnchorRecord`` items are wrapped in a dict holding the record, a
    synthetic ``anchor_<id>`` text and its integer start/end indices. Items
    that are already dicts are passed through unchanged (the same objects,
    not copies). A missing ``"start"`` or ``"end"`` key sorts as 0.
    """

    def to_payload(entry: dict[str, Any] | AnchorRecord) -> dict[str, Any]:
        """Wrap an AnchorRecord in a payload dict; return dicts as they are."""
        if not isinstance(entry, AnchorRecord):
            return entry
        return {
            "anchor": entry,
            "text": f"anchor_{entry.id}",
            "start": int(entry.start_idx),
            "end": int(entry.end_idx),
        }

    def span_key(payload: dict[str, Any]) -> tuple[int, int]:
        """Sort key: the payload's (start, end) span as integers."""
        return int(payload.get("start", 0)), int(payload.get("end", 0))

    return sorted((to_payload(entry) for entry in active_anchors), key=span_key)
def build_observed_tree(
    *,
    text: str,
    active_anchors: list[dict[str, Any] | AnchorRecord],
    future_hint_candidates: list[dict[str, Any]],
    auxiliary_proposals: list[dict[str, Any]],
    domain: str | None = None,
) -> AnchorTree:
    """Assemble the "observed" anchor tree for one piece of text.

    The tree is a root plus one node per usable anchor, future hint and
    auxiliary proposal. Nodes are ordered by the domain's step-rank table
    (unknown labels last), then by span start, then by node id. Nodes whose
    role is neither DRIFT nor META form a chain starting at the root; DRIFT
    and META nodes hang off the most recent chain node as ALTERNATIVE_TO
    edges and do not advance the chain.

    Args:
        text: Full input text; stored on the root node and in the metadata.
        active_anchors: ``AnchorRecord`` objects or payload dicts (keys read
            here: ``"anchor"``, ``"text"``, ``"start"``, ``"end"``).
        future_hint_candidates: Dicts; keys read here are ``"text"``,
            ``"start"``, ``"end"`` and ``"mean_score"``.
        auxiliary_proposals: Dicts; keys read here are ``"proposal_text"``,
            ``"proposal_span"``, ``"repr"`` and ``"proposal_score"``.
        domain: Explicit domain name. When falsy, the domain is detected
            from ``text`` and the anchor texts via ``detect_tree_domain``.

    Returns:
        The tree built by ``make_anchor_tree`` with ``source_kind="observed"``
        and input text / input counts recorded in ``tree.metadata``.
    """
    anchor_payloads = _make_anchor_payloads(active_anchors)
    anchor_texts = [str(payload.get("text", "")) for payload in anchor_payloads]
    resolved_domain = domain or detect_tree_domain(text=text, anchor_texts=anchor_texts)

    # The best-scoring AnchorRecord backs the root node. ``default=None``
    # already covers "no payloads" and "no payload carries a record".
    root_anchor = max(
        (payload["anchor"] for payload in anchor_payloads if isinstance(payload.get("anchor"), AnchorRecord)),
        key=_score_anchor,
        default=None,
    )
    root = _make_root_node(resolved_domain, text, root_anchor)
    order_map = _step_order(resolved_domain)

    normalized_items: list[dict[str, Any]] = []

    for idx, payload in enumerate(anchor_payloads):
        anchor = payload.get("anchor")
        payload_text = str(payload.get("text", "")).strip()
        node_label = classify_observed_label(resolved_domain, payload_text)
        # Skip blank anchors and anchors that merely restate the root label.
        if not payload_text or node_label == root.label:
            continue
        normalized_items.append(
            {
                "node_id": f"anchor_{idx}",
                "label": node_label,
                "text": payload_text,
                "span_start": int(payload.get("start", 0)),
                "span_end": int(payload.get("end", 0)),
                "repr": None if anchor is None else anchor.repr.detach().clone(),
                "score": 0.0 if anchor is None else _score_anchor(anchor),
                "source": "active_anchor",
            }
        )

    for idx, hint in enumerate(future_hint_candidates):
        hint_text = str(hint.get("text", "")).strip()
        if not hint_text:
            continue
        normalized_items.append(
            {
                "node_id": f"hint_{idx}",
                "label": classify_observed_label(resolved_domain, hint_text),
                "text": hint_text,
                "span_start": int(hint.get("start", 0)),
                # A hint without an explicit end is treated as ending where it starts.
                "span_end": int(hint.get("end", hint.get("start", 0))),
                "repr": None,
                "score": float(hint.get("mean_score", 0.0)),
                "source": "future_hint",
            }
        )

    for idx, proposal in enumerate(auxiliary_proposals):
        proposal_text = str(proposal.get("proposal_text", "")).strip()
        if not proposal_text:
            continue
        # Read the span only for proposals that are kept, and treat an
        # explicit ``None`` span like a missing one instead of failing to
        # unpack it.
        span = proposal.get("proposal_span")
        start, end = (0, 0) if span is None else span
        normalized_items.append(
            {
                "node_id": f"proposal_{idx}",
                "label": classify_observed_label(resolved_domain, proposal_text),
                "text": proposal_text,
                "span_start": int(start),
                "span_end": int(end),
                "repr": proposal.get("repr"),
                "score": float(proposal.get("proposal_score", 0.0)),
                "source": "auxiliary_proposal",
            }
        )

    # Known labels first in step-rank order, unknown labels (rank 999) last;
    # span start and node id break ties deterministically.
    normalized_items.sort(
        key=lambda item: (
            order_map.get(item["label"], 999),
            int(item["span_start"]),
            item["node_id"],
        )
    )

    nodes: list[AnchorTreeNode] = []
    edges: list[AnchorTreeEdge] = []
    off_path_roles = {AnchorTreeRole.DRIFT, AnchorTreeRole.META}
    last_progress_node_id = root.node_id

    for depth_idx, item in enumerate(normalized_items, start=1):
        label = str(item["label"])
        role = _role_for_label(label, str(item["source"]), is_root=False)
        is_off_path = role in off_path_roles
        node = AnchorTreeNode(
            node_id=str(item["node_id"]),
            label=label,
            text=str(item["text"]),
            # Depth is the node's 1-based position in the sorted sequence,
            # for on-path and off-path nodes alike.
            depth=depth_idx,
            role=role,
            source=str(item["source"]),
            span_start=int(item["span_start"]),
            span_end=int(item["span_end"]),
            repr=item.get("repr"),
            score=float(item["score"]),
            drift_flag=is_off_path,
        )
        nodes.append(node)

        # Every node attaches to the latest on-path node; only on-path nodes
        # then become the attachment point for what follows.
        parent_id = last_progress_node_id
        if is_off_path:
            relation = AnchorTreeRelation.ALTERNATIVE_TO
        else:
            relation = AnchorTreeRelation.CHILD if parent_id == root.node_id else AnchorTreeRelation.EXPECTED_NEXT
            last_progress_node_id = node.node_id
        edges.append(
            AnchorTreeEdge(parent_id=parent_id, child_id=node.node_id, relation=relation, score=float(node.score))
        )

    tree = make_anchor_tree(
        tree_id=f"observed_{resolved_domain or 'unknown'}",
        root=root,
        nodes=nodes,
        edges=edges,
        domain=resolved_domain or "unknown",
        source_kind="observed",
    )
    tree.metadata["input_text"] = text
    tree.metadata["anchor_count"] = len(anchor_payloads)
    tree.metadata["future_hint_count"] = len(future_hint_candidates)
    tree.metadata["auxiliary_proposal_count"] = len(auxiliary_proposals)
    return tree