fix: canonical.py benchmark regression fixes
Browse files
tensegrity/pipeline/canonical.py
CHANGED
|
@@ -150,6 +150,10 @@ class CanonicalPipeline:
|
|
| 150 |
llm_evidence_weight: float = 1.0,
|
| 151 |
# Persistent episodic recall enters as a memory-evidence channel.
|
| 152 |
memory_evidence_weight: float = 0.75,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
feedback_learning_rate: float = 1.0,
|
| 154 |
persistent_state_path: Optional[str] = None,
|
| 155 |
):
|
|
@@ -161,6 +165,7 @@ class CanonicalPipeline:
|
|
| 161 |
self.max_hypotheses = max(2, int(max_hypotheses))
|
| 162 |
self.llm_evidence_weight = float(llm_evidence_weight)
|
| 163 |
self.memory_evidence_weight = float(memory_evidence_weight)
|
|
|
|
| 164 |
self.feedback_learning_rate = float(feedback_learning_rate)
|
| 165 |
self.persistent_state_path = persistent_state_path
|
| 166 |
|
|
@@ -564,6 +569,7 @@ class CanonicalPipeline:
|
|
| 564 |
# if causal tension is high (the controller wires this internally).
|
| 565 |
initial_perception = self.ingest_prompt(sample.prompt)
|
| 566 |
memory_scores = self._memory_choice_scores(sample)
|
|
|
|
| 567 |
|
| 568 |
trace: List[IterationStep] = []
|
| 569 |
converged = False
|
|
@@ -601,12 +607,14 @@ class CanonicalPipeline:
|
|
| 601 |
fz = self._znorm(falsify)
|
| 602 |
lz = self._znorm(linguistic)
|
| 603 |
mz = self._znorm(memory_scores)
|
|
|
|
| 604 |
log_lik_falsify = self.falsify_update_strength * fz
|
| 605 |
log_post = (
|
| 606 |
np.log(np.maximum(old_belief, 1e-12))
|
| 607 |
+ log_lik_falsify
|
| 608 |
+ self.llm_evidence_weight * lz
|
| 609 |
+ self.memory_evidence_weight * mz
|
|
|
|
| 610 |
+ np.log(np.maximum(energy_post, 1e-12))
|
| 611 |
)
|
| 612 |
log_post -= log_post.max()
|
|
@@ -697,6 +705,47 @@ class CanonicalPipeline:
|
|
| 697 |
norm = np.linalg.norm(v)
|
| 698 |
return v / norm if norm > 1e-10 else v
|
| 699 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
def _memory_choice_scores(self, sample: TaskSample) -> np.ndarray:
|
| 701 |
"""Retrieve prior successful episodes and score choices by similarity.
|
| 702 |
|
|
|
|
| 150 |
llm_evidence_weight: float = 1.0,
|
| 151 |
# Persistent episodic recall enters as a memory-evidence channel.
|
| 152 |
memory_evidence_weight: float = 0.75,
|
| 153 |
+
# SBERT sentence similarity enters as a semantic-evidence channel.
|
| 154 |
+
# This is the strongest signal source: it compares the prompt against
|
| 155 |
+
# each (prompt+choice) concatenation using frozen sentence embeddings.
|
| 156 |
+
sbert_evidence_weight: float = 0.8,
|
| 157 |
feedback_learning_rate: float = 1.0,
|
| 158 |
persistent_state_path: Optional[str] = None,
|
| 159 |
):
|
|
|
|
| 165 |
self.max_hypotheses = max(2, int(max_hypotheses))
|
| 166 |
self.llm_evidence_weight = float(llm_evidence_weight)
|
| 167 |
self.memory_evidence_weight = float(memory_evidence_weight)
|
| 168 |
+
self.sbert_evidence_weight = float(sbert_evidence_weight)
|
| 169 |
self.feedback_learning_rate = float(feedback_learning_rate)
|
| 170 |
self.persistent_state_path = persistent_state_path
|
| 171 |
|
|
|
|
| 569 |
# if causal tension is high (the controller wires this internally).
|
| 570 |
initial_perception = self.ingest_prompt(sample.prompt)
|
| 571 |
memory_scores = self._memory_choice_scores(sample)
|
| 572 |
+
sbert_scores = self._sbert_choice_scores(sample)
|
| 573 |
|
| 574 |
trace: List[IterationStep] = []
|
| 575 |
converged = False
|
|
|
|
| 607 |
fz = self._znorm(falsify)
|
| 608 |
lz = self._znorm(linguistic)
|
| 609 |
mz = self._znorm(memory_scores)
|
| 610 |
+
sz = self._znorm(sbert_scores)
|
| 611 |
log_lik_falsify = self.falsify_update_strength * fz
|
| 612 |
log_post = (
|
| 613 |
np.log(np.maximum(old_belief, 1e-12))
|
| 614 |
+ log_lik_falsify
|
| 615 |
+ self.llm_evidence_weight * lz
|
| 616 |
+ self.memory_evidence_weight * mz
|
| 617 |
+
+ self.sbert_evidence_weight * sz
|
| 618 |
+ np.log(np.maximum(energy_post, 1e-12))
|
| 619 |
)
|
| 620 |
log_post -= log_post.max()
|
|
|
|
| 705 |
norm = np.linalg.norm(v)
|
| 706 |
return v / norm if norm > 1e-10 else v
|
| 707 |
|
| 708 |
+
def _sbert_choice_scores(self, sample: "TaskSample") -> np.ndarray:
    """Score each choice by SBERT cosine similarity to the prompt.

    The prompt is embedded on its own and once per choice as the
    concatenation ``"<prompt> <choice>"``; the score of a choice is the
    cosine similarity between the bare-prompt embedding and the embedding
    of its concatenation. The original author's note: unlike the NGC
    falsification path, this signal does not pass through the random
    FHRR->obs projection, so it measures relatedness directly in the
    sentence-embedding space.

    Args:
        sample: Object exposing ``prompt`` (str) and ``choices``
            (sequence of str).

    Returns:
        A float64 array with one entry per choice. The array is all zeros
        (a neutral, uninformative signal) when there are no choices, when
        no SBERT model is available from the encoder features, when the
        prompt embedding is degenerate (near-zero or non-finite norm), or
        when encoding/scoring fails. Individual choices whose embedding
        is degenerate also score zero.
    """
    n = len(sample.choices)
    scores = np.zeros(n, dtype=np.float64)
    if n == 0:
        return scores

    # The SBERT model is optional: it is looked up on the encoder's
    # feature object and the channel stays silent when it is absent.
    features = self.controller.agent.field.encoder.features
    getter = getattr(features, "get_sbert_model", None)
    sbert = getter() if callable(getter) else None
    if sbert is None:
        return scores

    try:
        texts = [sample.prompt]
        texts.extend(f"{sample.prompt} {c}" for c in sample.choices)
        embs = np.asarray(
            sbert.encode(texts, show_progress_bar=False), dtype=np.float64
        )
        if embs.ndim != 2 or embs.shape[0] < n + 1:
            raise ValueError(f"unexpected embedding shape {embs.shape}")

        prompt_emb = embs[0]
        choice_embs = embs[1 : n + 1]

        prompt_norm = float(np.linalg.norm(prompt_emb))
        # A NaN norm compares False against any threshold, so it must be
        # rejected explicitly or every score would become NaN.
        if not np.isfinite(prompt_norm) or prompt_norm < 1e-8:
            return scores

        choice_norms = np.linalg.norm(choice_embs, axis=1)
        usable = np.isfinite(choice_norms) & (choice_norms > 1e-8)

        # Compute into a scratch vector and publish it only on success,
        # so a failure never leaks a partially filled score vector.
        sims = np.zeros(n, dtype=np.float64)
        sims[usable] = (choice_embs[usable] @ prompt_emb) / (
            prompt_norm * choice_norms[usable]
        )
        scores = np.where(np.isfinite(sims), sims, 0.0)
    except Exception as e:
        # Best-effort channel: any failure degrades to the neutral
        # all-zero signal instead of aborting the pipeline.
        logger.debug("SBERT choice scoring failed: %s", e)

    return scores
|
| 748 |
+
|
| 749 |
def _memory_choice_scores(self, sample: TaskSample) -> np.ndarray:
|
| 750 |
"""Retrieve prior successful episodes and score choices by similarity.
|
| 751 |
|