"""
Retrieval modes for EEG semantic decoding.
Each mode is a callable class that takes:
embedding_index: EmbeddingIndex
nexus_conn: sqlite3 connection
semantic_embedding: torch.Tensor (the current predicted text embedding)
And returns:
list of strings (lines to print), or empty list (suppress output this frame)
Drop into EEGSemanticProcessor by replacing the find_similar_messages + _print_unique_lines
path with: mode.step(semantic_embedding) -> lines
"""
import numpy as np
import torch
import random
from collections import deque
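
# Minimal wiring sketch (hypothetical driver loop; EmbeddingIndex, the nexus
# sqlite database, and `embedding_stream` are assumed from the surrounding
# project, not defined here):
#
#     mode = DriftMode(embedding_index, nexus_conn)
#     for semantic_embedding in embedding_stream:  # one tensor per frame
#         for line in mode.step(semantic_embedding):
#             print(line)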
def fix_encoding(s):
    """Round-trip text through UTF-8; drop lines that still look like mojibake."""
    if not s:
        return s
    if isinstance(s, str):
        b = s.encode('utf-8', 'surrogateescape')
    else:
        b = s
    fixed = b.decode('utf-8', 'replace')
    # Check the decoded text (not the raw input, which may be bytes and would
    # raise TypeError on a str containment test) for common mojibake markers,
    # and suppress the line entirely if any are found.
    if 'ì' in fixed or 'í' in fixed or 'ï' in fixed:
        return ""
    return fixed
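
# Behavior sketch for fix_encoding (illustrative values): strings carrying
# surrogate escapes round-trip to replacement characters, while text still
# containing the mojibake markers above is suppressed entirely:
#
#     fix_encoding('caf\udce9')   ->  'caf\ufffd'  (lone surrogate replaced)
#     fix_encoding('f\u00ecxed')  ->  ''           ('ì' marker, line dropped)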
def _retrieve(embedding_index, nexus_conn, embedding_np, k=64, assistant_only=False):
"""Shared retrieval helper. Returns list of (content, distance) tuples."""
if len(embedding_np.shape) == 1:
embedding_np = embedding_np.reshape(1, -1)
distances, indices = embedding_index.search(embedding_np, k)
distances = distances.flatten()
indices = indices.flatten()
cursor = nexus_conn.cursor()
query = "SELECT content FROM messages WHERE id = ?"
if assistant_only:
query += " AND role = 'assistant'"
    results = []
    for msg_id, dist in zip(indices, distances):
        # FAISS-style indexes pad with -1 when fewer than k neighbors exist.
        if msg_id < 0:
            continue
        cursor.execute(query, (int(msg_id),))
        row = cursor.fetchone()
        if row and row[0]:
            results.append((row[0], float(dist)))
    return results
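
# _retrieve assumes a FAISS-like search contract on embedding_index. A
# minimal sketch of a compatible EmbeddingIndex (hypothetical; the real
# class is defined elsewhere in this project):
#
#     import faiss
#
#     class EmbeddingIndex:
#         def __init__(self, dim):
#             self.index = faiss.IndexFlatL2(dim)
#
#         def search(self, query, k):
#             # query: (1, dim) array -> (distances, ids), each shaped (1, k)
#             return self.index.search(np.asarray(query, dtype='float32'), k)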
def _lines_from_messages(messages, max_lines=60):
"""Extract individual lines from message contents, deduplicated."""
lines = []
seen = set()
for content in messages:
for line in content.splitlines():
line = line.strip()
if not line:
continue
line = fix_encoding(line)
if not line:
continue
if line not in seen:
seen.add(line)
lines.append(line)
if len(lines) >= max_lines:
return lines
return lines
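
# Example (illustrative): overlapping lines are emitted once, in first-seen
# order across messages:
#
#     _lines_from_messages(["alpha\nbeta", "beta\ngamma"])
#     ->  ["alpha", "beta", "gamma"]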
class FloodMode:
    """
    Original behavior: retrieve search_k candidates, keep the closest
    final_k, sample sample_size messages, and deduplicate lines against
    the last last_n frames. Fast, noisy, good for raw
    stream-of-consciousness.
    """
def __init__(self, embedding_index, nexus_conn, search_k=180, final_k=90,
sample_size=42, last_n=3):
self.embedding_index = embedding_index
self.nexus_conn = nexus_conn
self.search_k = search_k
self.final_k = final_k
self.sample_size = sample_size
self.previous_sets = deque(maxlen=last_n)
def step(self, semantic_embedding):
emb_np = semantic_embedding.detach().cpu().numpy()
results = _retrieve(self.embedding_index, self.nexus_conn, emb_np,
k=self.search_k)
messages = [content for content, _ in results[:self.final_k]]
if not messages:
return []
sample = random.sample(messages, min(self.sample_size, len(messages)))
current_lines = set()
for msg in sample:
for line in msg.splitlines():
line = line.strip()
if line:
current_lines.add(line)
unique = current_lines.copy()
for prev in self.previous_sets:
unique -= prev
self.previous_sets.append(current_lines)
unique = [l for l in map(fix_encoding, unique) if l]
return sorted(unique)
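
# Dedup window intuition: previous_sets stores the *sampled* lines of the
# last last_n frames (not just the printed ones), so a line resurfaces only
# after being absent from the sample for last_n consecutive frames.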
class DriftMode:
    """
    Emit output only when the semantic pointer moves significantly.
    Retrieves based on the *direction* of movement (current - previous),
    added to the current position. This amplifies whatever the signal
    is shifting toward.

    Parameters:
        move_threshold: minimum cosine distance between the current
            embedding and the embedding at the last emission (drift
            accumulates across suppressed frames until crossed)
        amplify: weight on the unit movement direction added to the
            current position (0.0 = pure position; larger values lean
            the query further toward where the signal is heading)
        search_k: candidates to retrieve
        cooldown: minimum steps between outputs
    """
def __init__(self, embedding_index, nexus_conn, search_k=64,
move_threshold=0.05, amplify=0.5, cooldown=3, max_lines=30):
self.embedding_index = embedding_index
self.nexus_conn = nexus_conn
self.search_k = search_k
self.move_threshold = move_threshold
self.amplify = amplify
self.cooldown = cooldown
self.max_lines = max_lines
self.prev_embedding = None
self.steps_since_emit = 0
self.prev_lines = set()
def step(self, semantic_embedding):
emb_np = semantic_embedding.detach().cpu().numpy().flatten()
# Normalize
norm = np.linalg.norm(emb_np)
if norm > 0:
emb_normed = emb_np / norm
else:
emb_normed = emb_np
self.steps_since_emit += 1
if self.prev_embedding is None:
self.prev_embedding = emb_normed
return []
# Compute movement
cos_sim = np.dot(emb_normed, self.prev_embedding)
cos_dist = 1.0 - cos_sim
if cos_dist < self.move_threshold or self.steps_since_emit < self.cooldown:
return []
# Direction of movement
delta = emb_normed - self.prev_embedding
delta_norm = np.linalg.norm(delta)
if delta_norm > 0:
delta = delta / delta_norm
# Query = current position + amplified direction
query = emb_normed + self.amplify * delta
query_norm = np.linalg.norm(query)
if query_norm > 0:
query = query / query_norm
self.prev_embedding = emb_normed
self.steps_since_emit = 0
results = _retrieve(self.embedding_index, self.nexus_conn,
query.reshape(1, -1), k=self.search_k)
        messages = [content for content, _ in results]
        all_lines = _lines_from_messages(messages, self.max_lines)
        # Suppress lines shown in the previous emission; remember the full
        # candidate set so a persistently retrieved line stays suppressed.
        lines = [l for l in all_lines if l not in self.prev_lines]
        self.prev_lines = set(all_lines)
        return lines
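
# Worked example of the DriftMode query (illustrative numbers): with
# amplify = 0.5, prev = [1.0, 0.0] and current = [0.8, 0.6] (both unit
# length), delta = [-0.2, 0.6] normalizes to about [-0.316, 0.949], so the
# query before its final normalization is roughly [0.642, 1.074], leaning
# past the current position toward where the signal is heading.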
class FocusMode:
"""
Maintain an exponential moving average of embeddings. Only emit
when the centroid shifts enough. Surfaces the persistent underlying
theme rather than moment-to-moment noise.
Parameters:
alpha: EMA smoothing factor (lower = smoother, more stable)
shift_threshold: minimum cosine distance of centroid movement to emit
search_k: candidates to retrieve
top_n: how many top results to show (closest to centroid)
"""
def __init__(self, embedding_index, nexus_conn, search_k=48,
alpha=0.15, shift_threshold=0.02, top_n=20, max_lines=25):
self.embedding_index = embedding_index
self.nexus_conn = nexus_conn
self.search_k = search_k
self.alpha = alpha
self.shift_threshold = shift_threshold
self.top_n = top_n
self.max_lines = max_lines
self.centroid = None
self.last_emit_centroid = None
self.prev_lines = set()
def step(self, semantic_embedding):
emb_np = semantic_embedding.detach().cpu().numpy().flatten()
norm = np.linalg.norm(emb_np)
if norm > 0:
emb_normed = emb_np / norm
else:
emb_normed = emb_np
# Update EMA centroid
if self.centroid is None:
self.centroid = emb_normed.copy()
self.last_emit_centroid = emb_normed.copy()
return []
self.centroid = self.alpha * emb_normed + (1.0 - self.alpha) * self.centroid
# Re-normalize centroid
c_norm = np.linalg.norm(self.centroid)
if c_norm > 0:
centroid_normed = self.centroid / c_norm
else:
centroid_normed = self.centroid
# Check if centroid has shifted enough since last emission
cos_sim = np.dot(centroid_normed, self.last_emit_centroid)
cos_dist = 1.0 - cos_sim
if cos_dist < self.shift_threshold:
return []
self.last_emit_centroid = centroid_normed.copy()
# Retrieve based on smoothed centroid
results = _retrieve(self.embedding_index, self.nexus_conn,
centroid_normed.reshape(1, -1), k=self.search_k)
        messages = [content for content, _ in results[:self.top_n]]
        all_lines = _lines_from_messages(messages, self.max_lines)
        # Deduplicate against the previous emission, then remember the full
        # line set (computed once rather than twice) for the next comparison.
        lines = [l for l in all_lines if l not in self.prev_lines]
        self.prev_lines = set(all_lines)
        return lines
class LayeredMode:
"""
Run multiple timescales simultaneously. Show three sections:
[fast] — what just changed (high threshold, small k)
[mid] — recent theme (EMA with medium alpha)
[slow] — deep undercurrent (EMA with low alpha)
Each layer only emits its section when its own threshold is crossed.
At least one layer must fire for any output.
"""
def __init__(self, embedding_index, nexus_conn, search_k=48, max_lines_per_layer=10):
self.layers = {
'fast': DriftMode(embedding_index, nexus_conn, search_k=search_k,
move_threshold=0.08, amplify=0.7, cooldown=1,
max_lines=max_lines_per_layer),
'mid': FocusMode(embedding_index, nexus_conn, search_k=search_k,
alpha=0.25, shift_threshold=0.03, top_n=16,
max_lines=max_lines_per_layer),
'slow': FocusMode(embedding_index, nexus_conn, search_k=search_k,
alpha=0.05, shift_threshold=0.015, top_n=12,
max_lines=max_lines_per_layer),
}
def step(self, semantic_embedding):
sections = {}
for name, layer in self.layers.items():
lines = layer.step(semantic_embedding)
if lines:
sections[name] = lines
if not sections:
return []
        output = []
        for name in ['fast', 'mid', 'slow']:
            if name in sections:
                output.append(f"── {name} ──")
                output.extend(sections[name])
                output.append("")
        # Drop the trailing blank separator after the last section.
        if output and output[-1] == "":
            output.pop()
        return output
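
# Illustrative LayeredMode output shape (placeholder lines, not real data):
#
#     ── fast ──
#     <lines from the high-threshold drift layer>
#
#     ── mid ──
#     <lines from the medium-alpha focus layer>
#
#     ── slow ──
#     <lines from the low-alpha focus layer>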