from __future__ import annotations

from typing import Any

from src.model.anchor_tree import make_anchor_tree
from src.model.anchor_tree_domain import detect_tree_domain
from src.model.anchor_tree_types import (
    AnchorTree,
    AnchorTreeEdge,
    AnchorTreeNode,
    AnchorTreeRelation,
    AnchorTreeRole,
)
from src.model.anchor_types import AnchorRecord

# Sort ranks per domain.  In every table the small ranks (0..7) are the labels
# that get a non-drift role, and the ranks from 50 upwards are the labels that
# `_role_for_label` maps to DRIFT or META; labels missing from a table are
# ranked 999 by `build_observed_tree`.
_MATH_STEP_ORDER = {
    "integration_by_parts_only": 0,
    "select_u_and_dv": 1,
    "derive_du_and_v": 2,
    "substitute_uv_minus_int_vdu": 3,
    "reduce_integral_complexity": 4,
    "repeat_if_needed": 5,
    "simplify_result": 6,
    "integration_constant": 7,
    "shortcut_lookup": 50,
    "table_reference": 51,
    "substitution_switch": 52,
    "meta_abort": 53,
    "wrong_symbolic_step": 54,
}

_CODE_STEP_ORDER = {
    "async_fastapi_service": 0,
    "typed_request_models": 1,
    "dependency_injection": 2,
    "async_handlers": 3,
    "validation_path": 4,
    "background_jobs": 5,
    "deployment_notes": 6,
    "django_view_reframe": 50,
    "synchronous_handler_reframe": 51,
    "template_rendering_branch": 52,
}

_QUANTIFIER_STEP_ORDER = {
    "universal_quantifier_scope": 0,
    "preserve_universal_claim": 1,
    "reject_existential_drift": 2,
    "restate_universal_conclusion": 3,
    "existential_witness_shift": 50,
    "drop_universal_scope": 51,
}

_PROOF_MODE_STEP_ORDER = {
    "proof_by_contradiction_mode": 0,
    "maintain_negation_assumption": 1,
    "derive_contradiction": 2,
    "discharge_negation_assumption": 3,
    "direct_proof_switch": 50,
    "constructive_reset": 51,
}

# Labels that `_role_for_label` maps to AnchorTreeRole.DRIFT.
_DRIFT_LABELS = frozenset(
    {
        "shortcut_lookup",
        "table_reference",
        "substitution_switch",
        "wrong_symbolic_step",
        "django_view_reframe",
        "synchronous_handler_reframe",
        "template_rendering_branch",
        "existential_witness_shift",
        "drop_universal_scope",
        "direct_proof_switch",
        "constructive_reset",
    }
)


def _domain_root_label(domain: str | None) -> str:
    """Return the root-node label for *domain* (``"observed_root"`` if unknown)."""
    if domain == "math_ibp":
        return "integration_by_parts_only"
    if domain == "code_fastapi":
        return "async_fastapi_service"
    if domain == "quantifier":
        return "universal_quantifier_scope"
    if domain == "proof_mode":
        return "proof_by_contradiction_mode"
    return "observed_root"


def _step_order(domain: str | None) -> dict[str, int]:
    """Return the label -> sort-rank table for *domain* (empty if unknown)."""
    if domain == "math_ibp":
        return _MATH_STEP_ORDER
    if domain == "code_fastapi":
        return _CODE_STEP_ORDER
    if domain == "quantifier":
        return _QUANTIFIER_STEP_ORDER
    if domain == "proof_mode":
        return _PROOF_MODE_STEP_ORDER
    return {}


def _classify_math_label(text: str) -> str:
    """Map free text to a ``math_ibp`` label by case-insensitive substring tests.

    The tests run in order and the first match wins, so their order is part
    of the behaviour.  Falls back to ``"math_observed_step"``.
    """
    lowered = text.lower()
    if "integration by parts" in lowered:
        return "integration_by_parts_only"
    if ("let u" in lowered or "u =" in lowered) and "dv" in lowered:
        return "select_u_and_dv"
    # Padding with spaces lets a bare "v" at either end of the text match " v ".
    if "du" in lowered and ("v =" in lowered or " v " in f" {lowered} "):
        return "derive_du_and_v"
    if "uv" in lowered or "vdu" in lowered:
        return "substitute_uv_minus_int_vdu"
    # The noun "substitution" must be tested before the broader "substitut"
    # stem: the stem is a prefix of the noun, so testing the stem first made
    # the "substitution_switch" label unreachable.
    if "substitution" in lowered:
        return "substitution_switch"
    if "substitut" in lowered:
        return "substitute_uv_minus_int_vdu"
    if "remaining integral" in lowered or "reduce" in lowered:
        return "reduce_integral_complexity"
    if "repeat" in lowered or "again" in lowered:
        return "repeat_if_needed"
    if "+ c" in lowered or "+c" in lowered or "constant of integration" in lowered:
        return "integration_constant"
    if "simplif" in lowered:
        return "simplify_result"
    if "shortcut" in lowered:
        return "shortcut_lookup"
    if "table" in lowered:
        return "table_reference"
    if any(
        marker in lowered
        for marker in ("too hard", "challenging", "no clear path", "alternative approach")
    ):
        return "meta_abort"
    return "math_observed_step"


def _classify_code_label(text: str) -> str:
    """Map free text to a ``code_fastapi`` label by case-insensitive substring tests.

    First match wins; falls back to ``"code_observed_step"``.
    """
    lowered = text.lower()
    if "fastapi" in lowered and "async" in lowered:
        return "async_fastapi_service"
    if "pydantic" in lowered or "request model" in lowered or "response model" in lowered:
        return "typed_request_models"
    if "dependency injection" in lowered or "depends(" in lowered:
        return "dependency_injection"
    if "async handler" in lowered or ("async def" in lowered and "await" in lowered):
        return "async_handlers"
    if "validation" in lowered or "validate" in lowered:
        return "validation_path"
    if "background task" in lowered or "background job" in lowered:
        return "background_jobs"
    if "deploy" in lowered or "uvicorn" in lowered or "gunicorn" in lowered:
        return "deployment_notes"
    if "django" in lowered:
        return "django_view_reframe"
    if "synchronous" in lowered or "sync view" in lowered:
        return "synchronous_handler_reframe"
    if "template" in lowered or "render" in lowered:
        return "template_rendering_branch"
    return "code_observed_step"


def _classify_quantifier_label(text: str) -> str:
    """Map free text to a ``quantifier`` label by case-insensitive substring tests.

    First match wins; falls back to ``"quantifier_observed_step"``.
    """
    lowered = text.lower()
    if "restore the original universal" in lowered or "restate the universal" in lowered:
        return "restate_universal_conclusion"
    # "reject ... existential" must be tested before the plain "existential"
    # test below, otherwise a rejection would be labelled as the drift itself.
    if "reject" in lowered and ("existential" in lowered or "witness" in lowered):
        return "reject_existential_drift"
    if "there exists" in lowered or "one witness" in lowered or "existential" in lowered:
        return "existential_witness_shift"
    if "drop" in lowered and "universal" in lowered:
        return "drop_universal_scope"
    if "for all" in lowered or "universal" in lowered:
        return "preserve_universal_claim"
    return "quantifier_observed_step"


def _classify_proof_mode_label(text: str) -> str:
    """Map free text to a ``proof_mode`` label by case-insensitive substring tests.

    First match wins; falls back to ``"proof_mode_observed_step"``.
    """
    lowered = text.lower()
    if "contradiction" in lowered and (
        "keep" in lowered or "return" in lowered or "mode" in lowered
    ):
        return "proof_by_contradiction_mode"
    if (
        "assume the negation" in lowered
        or "assumed negation" in lowered
        or "negation assumption" in lowered
    ):
        return "maintain_negation_assumption"
    if "derive a contradiction" in lowered or "contradiction structure" in lowered:
        return "derive_contradiction"
    if "assumption was false" in lowered or "discharge the assumed negation" in lowered:
        return "discharge_negation_assumption"
    if "direct proof" in lowered:
        return "direct_proof_switch"
    if "constructive proof" in lowered or "from scratch" in lowered:
        return "constructive_reset"
    return "proof_mode_observed_step"


def classify_observed_label(domain: str | None, text: str) -> str:
    """Classify *text* with the classifier of *domain*.

    Returns ``"observed_step"`` when *domain* is not one of ``"math_ibp"``,
    ``"code_fastapi"``, ``"quantifier"`` or ``"proof_mode"``.
    """
    if domain == "math_ibp":
        return _classify_math_label(text)
    if domain == "code_fastapi":
        return _classify_code_label(text)
    if domain == "quantifier":
        return _classify_quantifier_label(text)
    if domain == "proof_mode":
        return _classify_proof_mode_label(text)
    return "observed_step"


def _role_for_label(label: str, source: str, is_root: bool) -> AnchorTreeRole:
    """Choose the node role from its label, its source and whether it is the root.

    Precedence: root -> CONSTRAINT; drift label -> DRIFT; ``"meta_abort"`` ->
    META; source ``"auxiliary_proposal"`` -> REPAIR; source ``"future_hint"``
    -> DERIVED; anything else -> STEP.
    """
    if is_root:
        return AnchorTreeRole.CONSTRAINT
    if label in _DRIFT_LABELS:
        return AnchorTreeRole.DRIFT
    if label == "meta_abort":
        return AnchorTreeRole.META
    if source == "auxiliary_proposal":
        return AnchorTreeRole.REPAIR
    if source == "future_hint":
        return AnchorTreeRole.DERIVED
    return AnchorTreeRole.STEP


def _as_float(value: Any) -> float:
    """Convert *value* to ``float``; values exposing ``detach()`` go through ``.item()``."""
    if hasattr(value, "detach"):
        return float(value.detach().item())
    return float(value)


def _score_anchor(anchor: AnchorRecord) -> float:
    """Return ``support * max(viability, 1e-6)`` for *anchor*.

    The floor on viability keeps a zero or negative viability from zeroing
    out or flipping the sign of the support term.
    """
    support = _as_float(anchor.support)
    viability = _as_float(anchor.viability)
    return support * max(viability, 1e-6)


def _make_root_node(
    domain: str | None, text: str, root_anchor: AnchorRecord | None
) -> AnchorTreeNode:
    """Build the root node (id ``"root"``, depth 0, role CONSTRAINT, source ``"prompt"``).

    When *root_anchor* is given, its id, span, score and a detached clone of
    its ``repr`` are copied onto the node.  Otherwise the span defaults to
    ``0 .. max(0, len(text.split()) - 1)``, the score to ``0.0`` and
    ``anchor_id`` / ``repr`` to ``None``.
    """
    if root_anchor is not None:
        root_score = _score_anchor(root_anchor)
        root_repr = root_anchor.repr.detach().clone()
        root_start = int(root_anchor.start_idx)
        root_end = int(root_anchor.end_idx)
        anchor_id = int(root_anchor.id)
    else:
        root_score = 0.0
        root_repr = None
        root_start = 0
        # Index of the last whitespace-separated token, clamped for empty text.
        root_end = max(0, len(text.split()) - 1)
        anchor_id = None
    return AnchorTreeNode(
        node_id="root",
        label=_domain_root_label(domain),
        text=text,
        depth=0,
        role=AnchorTreeRole.CONSTRAINT,
        source="prompt",
        anchor_id=anchor_id,
        span_start=root_start,
        span_end=root_end,
        repr=root_repr,
        score=float(root_score),
    )


def _make_anchor_payloads(
    active_anchors: list[dict[str, Any] | AnchorRecord],
) -> list[dict[str, Any]]:
    """Normalise *active_anchors* into payload dicts sorted by ``(start, end)``.

    ``AnchorRecord`` items are wrapped as ``{"anchor", "text", "start", "end"}``
    with the placeholder text ``anchor_<id>``; dict items are passed through
    unchanged.  The input list itself is not modified.
    """
    payloads: list[dict[str, Any]] = []
    for item in active_anchors:
        if isinstance(item, AnchorRecord):
            payloads.append(
                {
                    "anchor": item,
                    "text": f"anchor_{item.id}",
                    "start": int(item.start_idx),
                    "end": int(item.end_idx),
                }
            )
        else:
            payloads.append(item)
    payloads.sort(
        key=lambda payload: (int(payload.get("start", 0)), int(payload.get("end", 0)))
    )
    return payloads


def build_observed_tree(
    *,
    text: str,
    active_anchors: list[dict[str, Any] | AnchorRecord],
    future_hint_candidates: list[dict[str, Any]],
    auxiliary_proposals: list[dict[str, Any]],
    domain: str | None = None,
) -> AnchorTree:
    """Build an "observed" anchor tree from anchors, future hints and proposals.

    Steps:

    1. Normalise the anchors and resolve the domain (the explicit *domain*,
       or ``detect_tree_domain`` when it is falsy).
    2. Build the root from the highest-scoring ``AnchorRecord`` among the
       anchor payloads, if any.
    3. Classify every anchor / hint / proposal text into a label.  Items with
       empty text are skipped, as are anchors whose label equals the root's.
    4. Sort the items by ``(label rank, span_start, node_id)`` and chain them:
       non-drift nodes hang off the previous non-drift node (or the root),
       while DRIFT / META nodes attach to the current non-drift node as
       alternatives without advancing the chain.

    Args:
        text: Input text; stored on the root node and in the tree metadata.
        active_anchors: ``AnchorRecord`` objects or payload dicts (keys read:
            ``anchor``, ``text``, ``start``, ``end``).
        future_hint_candidates: Dicts (keys read: ``text``, ``start``, ``end``,
            ``mean_score``).
        auxiliary_proposals: Dicts (keys read: ``proposal_text``,
            ``proposal_span``, ``repr``, ``proposal_score``).
        domain: Optional domain override.

    Returns:
        The tree produced by ``make_anchor_tree`` with ``source_kind="observed"``
        and input counts recorded in ``tree.metadata``.
    """
    anchor_payloads = _make_anchor_payloads(active_anchors)
    anchor_texts = [str(payload.get("text", "")) for payload in anchor_payloads]
    resolved_domain = domain or detect_tree_domain(text=text, anchor_texts=anchor_texts)

    # `default=None` already covers the "no payloads / no AnchorRecord" case.
    root_anchor = max(
        (
            payload["anchor"]
            for payload in anchor_payloads
            if isinstance(payload.get("anchor"), AnchorRecord)
        ),
        key=_score_anchor,
        default=None,
    )
    root = _make_root_node(resolved_domain, text, root_anchor)
    order_map = _step_order(resolved_domain)

    normalized_items: list[dict[str, Any]] = []

    for idx, payload in enumerate(anchor_payloads):
        anchor = payload.get("anchor")
        payload_text = str(payload.get("text", "")).strip()
        node_label = classify_observed_label(resolved_domain, payload_text)
        # An anchor that classifies as the root label duplicates the root.
        if not payload_text or node_label == root.label:
            continue
        normalized_items.append(
            {
                "node_id": f"anchor_{idx}",
                "label": node_label,
                "text": payload_text,
                "span_start": int(payload.get("start", 0)),
                "span_end": int(payload.get("end", 0)),
                "repr": None if anchor is None else anchor.repr.detach().clone(),
                "score": 0.0 if anchor is None else _score_anchor(anchor),
                "source": "active_anchor",
            }
        )

    for idx, hint in enumerate(future_hint_candidates):
        hint_text = str(hint.get("text", "")).strip()
        if not hint_text:
            continue
        normalized_items.append(
            {
                "node_id": f"hint_{idx}",
                "label": classify_observed_label(resolved_domain, hint_text),
                "text": hint_text,
                "span_start": int(hint.get("start", 0)),
                # A hint without an explicit end collapses to its start.
                "span_end": int(hint.get("end", hint.get("start", 0))),
                "repr": None,
                "score": float(hint.get("mean_score", 0.0)),
                "source": "future_hint",
            }
        )

    for idx, proposal in enumerate(auxiliary_proposals):
        proposal_text = str(proposal.get("proposal_text", "")).strip()
        # Skip before touching the span so that a proposal which is dropped
        # anyway cannot fail on a missing or malformed "proposal_span".
        if not proposal_text:
            continue
        span = proposal.get("proposal_span")
        start, end = (0, 0) if span is None else span
        normalized_items.append(
            {
                "node_id": f"proposal_{idx}",
                "label": classify_observed_label(resolved_domain, proposal_text),
                "text": proposal_text,
                "span_start": int(start),
                "span_end": int(end),
                "repr": proposal.get("repr"),
                "score": float(proposal.get("proposal_score", 0.0)),
                "source": "auxiliary_proposal",
            }
        )

    # Labels missing from the domain table are ranked 999 so they sort last.
    normalized_items.sort(
        key=lambda item: (
            order_map.get(item["label"], 999),
            int(item["span_start"]),
            item["node_id"],
        )
    )

    nodes: list[AnchorTreeNode] = []
    edges: list[AnchorTreeEdge] = []
    branch_roles = {AnchorTreeRole.DRIFT, AnchorTreeRole.META}
    last_progress_node_id = root.node_id

    for depth_idx, item in enumerate(normalized_items, start=1):
        label = str(item["label"])
        source = str(item["source"])
        role = _role_for_label(label, source, is_root=False)
        is_branch = role in branch_roles
        node = AnchorTreeNode(
            node_id=str(item["node_id"]),
            label=label,
            text=str(item["text"]),
            # Depth is the 1-based position in the sorted sequence for every
            # role, drift and meta nodes included.
            depth=depth_idx,
            role=role,
            source=source,
            span_start=int(item["span_start"]),
            span_end=int(item["span_end"]),
            repr=item.get("repr"),
            score=float(item["score"]),
            drift_flag=is_branch,
        )
        nodes.append(node)

        parent_id = last_progress_node_id
        if is_branch:
            # Drift / meta nodes branch off without advancing the chain.
            relation = AnchorTreeRelation.ALTERNATIVE_TO
        else:
            if parent_id == root.node_id:
                relation = AnchorTreeRelation.CHILD
            else:
                relation = AnchorTreeRelation.EXPECTED_NEXT
            last_progress_node_id = node.node_id
        edges.append(
            AnchorTreeEdge(
                parent_id=parent_id,
                child_id=node.node_id,
                relation=relation,
                score=float(node.score),
            )
        )

    tree = make_anchor_tree(
        tree_id=f"observed_{resolved_domain or 'unknown'}",
        root=root,
        nodes=nodes,
        edges=edges,
        domain=resolved_domain or "unknown",
        source_kind="observed",
    )
    tree.metadata["input_text"] = text
    tree.metadata["anchor_count"] = len(anchor_payloads)
    tree.metadata["future_hint_count"] = len(future_hint_candidates)
    tree.metadata["auxiliary_proposal_count"] = len(auxiliary_proposals)
    return tree