# v2/src/planning_quality.py
# Uploaded by JulianHJR via huggingface_hub (commit e53f10b, verified)
"""
Planning Quality Score (PQS) — multi-dimensional assessment.
Motivation:
When we steer planning, a drop in trigger count could mean:
(a) The model truly lost planning capability, OR
(b) The model bypassed the trigger words while still planning internally.
We need to distinguish these by measuring deeper quality signals.
Four dimensions:
Q1. Structural Depth — sequential structure + forward-references
Q2. Strategy Diversity — explicit strategy naming + comparison
Q3. Long-Range Coherence — later text referring back to plan declaration
Q4. Premature Execution — how early the first computation appears (NEGATIVE signal)
Combined:
PQS = 0.25 × Q1 + 0.25 × Q2 + 0.25 × Q3 + 0.25 × (1 - Q4)
"""
import re
from typing import Dict, List
# ============================================================
# Q1. Structural Depth
# ============================================================
# Regexes that signal ordinal / sequential structure ("step 1", "First,", ...).
# All use the inline (?i) flag for case-insensitive matching; most require a
# trailing comma or period so that mid-sentence uses are less likely to count.
_ORDINAL_PATTERNS = [
    r"(?i)\bstep\s+\d+\b",
    r"(?i)\bfirst(ly)?[,.]",
    r"(?i)\bsecond(ly)?[,.]",
    r"(?i)\bthird(ly)?[,.]",
    r"(?i)\bnext[,.]",
    r"(?i)\bthen[,.]",
    r"(?i)\bfinally[,.]",
    r"(?i)\blast(ly)?[,.]",
]
# Regexes for forward references — language that promises future work
# ("we will need to", "later we'll", ...), a marker of upfront planning.
_FORWARD_REF_PATTERNS = [
    r"(?i)\bwe\s+will\s+need\s+to\b",
    r"(?i)\blater\s+(we'?ll|we\s+will|i'?ll|i\s+will)\b",
    r"(?i)\bthis\s+will\s+help\s+(us|me)\s+later\b",
    r"(?i)\bin\s+a\s+(later|subsequent)\s+step\b",
    r"(?i)\b(we|i)'?ll\s+(come|get)\s+back\s+to\s+(this|that)\b",
    r"(?i)\bfor\s+now[,.]",
    r"(?i)\blater\s+on\b",
]
def compute_q1_structural_depth(text: str) -> float:
    """
    Score structural depth of *text* in [0, 1].

    Blends three signals:
    - density of ordinal phrases (step 1, first, next, ...)
    - density of forward references (we'll need to, later we'll, ...)
    - completeness of the ordinal sequence (first -> next -> finally)
    """
    word_count = len(text.split()) or 1  # guard against division by zero
    # Ordinal phrase density: one hit per 100 words saturates the signal.
    ordinal_hits = 0
    for pattern in _ORDINAL_PATTERNS:
        ordinal_hits += len(re.findall(pattern, text))
    ordinal_density = min(ordinal_hits / (word_count / 100), 1.0)
    # Forward-reference density: one hit per 200 words saturates the signal.
    forward_hits = sum(len(re.findall(pattern, text)) for pattern in _FORWARD_REF_PATTERNS)
    forward_density = min(forward_hits / (word_count / 200), 1.0)
    # Completeness: number of distinct ordinal markers anywhere in the text
    # (3+ distinct markers = full score).
    markers = ("first", "second", "third", "next", "then", "finally", "lastly")
    distinct_count = sum(1 for marker in markers if re.search(rf"(?i)\b{marker}\b", text))
    completeness = min(distinct_count / 3.0, 1.0)
    return 0.4 * ordinal_density + 0.3 * forward_density + 0.3 * completeness
# ============================================================
# Q2. Strategy Diversity
# ============================================================
# Regexes naming explicit solution strategies, grouped by mathematical domain.
# Each distinct term found contributes to the Q2 "strategy diversity" score.
_STRATEGY_TERMS = [
    # proof techniques
    r"(?i)\binduction\b", r"(?i)\bcontradiction\b", r"(?i)\bcontrapositive\b",
    r"(?i)\bconstructive\s+proof\b", r"(?i)\bpigeonhole\b",
    # algebraic
    r"(?i)\bsubstitution\b", r"(?i)\belimination\b", r"(?i)\bfactoring\b",
    r"(?i)\bcompleting\s+the\s+square\b",
    # case-based
    r"(?i)\bcase\s+analysis\b", r"(?i)\bby\s+cases\b",
    # geometric / graphical
    r"(?i)\bcoordinate\s+geometry\b", r"(?i)\bgraphing\b",
    # combinatorial / number-theoretic
    r"(?i)\bgenerating\s+function\b", r"(?i)\bmodular\s+arithmetic\b",
    r"(?i)\bchinese\s+remainder\s+theorem\b",
    # generic
    r"(?i)\bwithout\s+loss\s+of\s+generality\b", r"(?i)\bwlog\b",
    r"(?i)\bworking?\s+backward(s)?\b",
]
# Regexes for comparing/weighing alternative strategies ("another approach",
# "alternatively", ...). Only counted early in the CoT by Q2 — late occurrences
# are treated as backtracking rather than planning.
_STRATEGY_COMPARISON_PATTERNS = [
    r"(?i)\banother\s+(way|approach|method)\b",
    r"(?i)\balternatively[,.]",
    r"(?i)\bwe\s+could\s+(also|instead)\b",
    r"(?i)\bone\s+approach\s+(is|would\s+be)\b",
    r"(?i)\ba\s+different\s+(way|approach|method)\b",
]
def compute_q2_strategy_diversity(text: str) -> float:
    """
    Score strategy diversity of *text* in [0, 1].

    Blends three signals:
    - count of unique strategy terms mentioned anywhere
    - strategy comparisons within the first 20% of the text (later
      comparisons read as backtracking, not planning-phase comparison)
    - overall strategy-term density
    """
    if not text.strip():
        return 0.0
    planning_prefix = text[: int(len(text) * 0.20)]
    # Unique strategy terms: 3+ distinct terms = full score.
    unique_terms = {pat for pat in _STRATEGY_TERMS if re.search(pat, text)}
    strategy_unique = min(len(unique_terms) / 3.0, 1.0)
    # Early comparisons: 2+ within the first 20% = full score.
    early_compares = sum(
        len(re.findall(pat, planning_prefix)) for pat in _STRATEGY_COMPARISON_PATTERNS
    )
    compare_score = min(early_compares / 2.0, 1.0)
    # Overall density: one strategy mention per 300 words = full score.
    total_mentions = sum(len(re.findall(pat, text)) for pat in _STRATEGY_TERMS)
    word_count = max(len(text.split()), 1)
    density = min(total_mentions / (word_count / 300), 1.0)
    return 0.4 * strategy_unique + 0.3 * compare_score + 0.3 * density
# ============================================================
# Q3. Long-Range Coherence
# ============================================================
# Regexes for back-references to an earlier plan ("as planned", "per my
# plan", ...). Q3 counts these only in the second half of the CoT, where they
# indicate execution that follows a previously declared plan.
_BACK_REFERENCE_PATTERNS = [
    r"(?i)\bas\s+(planned|outlined|mentioned)\b",
    r"(?i)\bfollowing\s+(the|my|our)\s+(plan|approach|strategy)\b",
    r"(?i)\bas\s+i\s+(planned|said|outlined|mentioned)\s+(earlier|before|above)\b",
    r"(?i)\bas\s+we\s+(planned|said|outlined)\s+(earlier|before|above)\b",
    r"(?i)\bgoing\s+back\s+to\s+(step\s+\d+|my\s+plan)\b",
    r"(?i)\brecall(ing)?\s+(that|my|the)\s+(plan|approach|strategy)\b",
    r"(?i)\bper\s+(my|the)\s+(plan|strategy)\b",
]
def compute_q3_long_range_coherence(text: str) -> float:
    """
    Score long-range coherence of *text* in [0, 1].

    Counts back-references to early planning ("as planned", "per my plan")
    in the second half of the text — a signature of structured execution.
    Texts shorter than 200 characters score 0 (no room for long range).
    """
    if not text.strip():
        return 0.0
    if len(text) < 200:
        return 0.0
    # Only the second half counts: a back-reference is meaningful only
    # after enough text has elapsed since the plan declaration.
    tail = text[len(text) // 2:]
    back_refs = 0
    for pattern in _BACK_REFERENCE_PATTERNS:
        back_refs += len(re.findall(pattern, tail))
    # 2+ back-references in the second half = full score.
    return min(back_refs / 2.0, 1.0)
# ============================================================
# Q4. Premature Execution (NEGATIVE signal)
# ============================================================
def compute_q4_premature_execution(text: str) -> float:
    """
    Score premature execution of *text* in [0, 1]; higher = more premature
    (a NEGATIVE signal for planning quality).

    Locates the earliest equation/computation and maps its position, as a
    fraction of total length, through a continuous piecewise-linear ramp:
    - first 5% of the text          -> 1.0 (strongly premature)
    - 5%-20%                        -> 1.0 down to 0.2 (premature)
    - 20%-30%                       -> 0.2 down to 0.0
    - after 30%, or no computation  -> 0.0 (not premature)
    """
    if not text.strip():
        return 0.0
    computation_patterns = [
        r"[-+]?\d+\s*[+\-*/]\s*\d+\s*=",  # "3+5="
        r"[a-zA-Z]\s*=\s*\d",             # "x = 5"
        r"\d+\s*=\s*\d+",                 # "10 = 10"
        r"\$.*?=.*?\$",                   # LaTeX: "$...=...$"
    ]
    # Earliest match position across all computation patterns, if any.
    starts = [m.start() for m in (re.search(p, text) for p in computation_patterns) if m]
    if not starts:
        return 0.0  # no computation found -> not premature
    ratio = min(starts) / len(text)
    if ratio < 0.05:
        return 1.0
    if ratio < 0.20:
        # Linear ramp 1.0 -> 0.2 across the 5%-20% band.
        return 1.0 - (ratio - 0.05) / 0.15 * 0.8
    if ratio < 0.30:
        # Linear ramp 0.2 -> 0.0 across the 20%-30% band.
        return 0.2 - (ratio - 0.20) / 0.10 * 0.2
    return 0.0
# ============================================================
# Combined PQS
# ============================================================
def compute_pqs(text: str) -> Dict:
    """
    Compute all four quality dimensions and the combined PQS for *text*.

    Returns a dict with the four component scores plus "pqs", where
    PQS = 0.25*Q1 + 0.25*Q2 + 0.25*Q3 + 0.25*(1 - Q4).
    Q4 (premature execution) is a negative signal, so it enters inverted.
    """
    scores = {
        "q1_structural_depth": compute_q1_structural_depth(text),
        "q2_strategy_diversity": compute_q2_strategy_diversity(text),
        "q3_long_range_coherence": compute_q3_long_range_coherence(text),
        "q4_premature_execution": compute_q4_premature_execution(text),
    }
    combined = (
        0.25 * scores["q1_structural_depth"]
        + 0.25 * scores["q2_strategy_diversity"]
        + 0.25 * scores["q3_long_range_coherence"]
        + 0.25 * (1.0 - scores["q4_premature_execution"])
    )
    result = {key: float(value) for key, value in scores.items()}
    result["pqs"] = float(combined)
    return result