# Provenance: uploaded by Aluode ("Upload 5 files", commit f65b63e, verified)
"""
ConjunctionReservoir — core retriever
"""
import numpy as np
import re
import time
from typing import Dict, List, Optional, Tuple, Union
def split_sentences(text: str, min_len: int = 15) -> List[str]:
    """Split *text* on sentence punctuation (., !, ?) and return the stripped
    pieces that are at least *min_len* characters long."""
    stripped = (piece.strip() for piece in re.split(r"[.!?]+", text))
    return [piece for piece in stripped if len(piece) >= min_len]
def chunk_document(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    """
    Split raw document text into retrieval chunks.

    The text is first split on newline boundaries that introduce an
    email/section header (``From:``, ``Subject:``, ``Date:``, ``---``).
    Sections shorter than 50 characters are dropped; sections longer than
    *chunk_size* are re-split into overlapping character windows.

    Args:
        text: Raw document text.
        chunk_size: Maximum chunk length in characters.
        overlap: Character overlap between consecutive windows of a long
            section.  Must be smaller than *chunk_size*.

    Returns:
        List of chunk strings (each longer than ~50 characters).

    Raises:
        ValueError: If ``overlap >= chunk_size``.  Previously a zero step
            raised an opaque ``range()`` error and a negative step silently
            produced an empty window loop, dropping long sections entirely.
    """
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    step = chunk_size - overlap  # hoisted loop-invariant window stride
    sections = re.split(r"\n(?=From:|Subject:|Date:|---)", text)
    chunks: List[str] = []
    for section in sections:
        section = section.strip()
        if len(section) < 50:
            continue  # too short to be a useful retrieval unit
        if len(section) <= chunk_size:
            chunks.append(section)
        else:
            for i in range(0, len(section), step):
                chunk = section[i : i + chunk_size].strip()
                if len(chunk) > 50:
                    chunks.append(chunk)
    return chunks
def build_vocab(texts: List[str], max_vocab: int = 2000) -> Dict[str, int]:
    """Map the *max_vocab* most frequent alphabetic tokens (>= 2 letters,
    lowercased) across *texts* to dense integer indices, most frequent first."""
    freq: Dict[str, int] = {}
    for text in texts:
        for token in re.findall(r"\b[a-zA-Z]{2,}\b", text.lower()):
            freq[token] = freq.get(token, 0) + 1
    # Stable sort: ties keep first-seen order, matching insertion order.
    ranked = sorted(freq.items(), key=lambda item: -item[1])
    return {word: index for index, (word, _count) in enumerate(ranked[:max_vocab])}
def tfidf_weights(sentences: List[str], vocab: Dict[str, int]) -> np.ndarray:
    """Return a smoothed IDF vector aligned with *vocab*'s indices:
    ``log((N + 1) / (df + 1)) + 1`` where df is the per-token document
    frequency over *sentences*."""
    n_docs = len(sentences)
    doc_freq = np.zeros(len(vocab))
    for sentence in sentences:
        tokens = set(re.findall(r"\b[a-zA-Z]{2,}\b", sentence.lower()))
        for token in tokens:
            column = vocab.get(token)
            if column is not None:
                doc_freq[column] += 1
    return np.log((n_docs + 1) / (doc_freq + 1)) + 1.0
def encode_text(text: str, vocab: Dict[str, int], idf: np.ndarray) -> np.ndarray:
    """Encode *text* as an L2-normalized TF-IDF vector over *vocab*.
    Out-of-vocabulary tokens are ignored; an all-OOV text yields (near-)zeros."""
    counts = np.zeros(len(vocab))
    for token in re.findall(r"\b[a-zA-Z]{2,}\b", text.lower()):
        column = vocab.get(token)
        if column is not None:
            counts[column] += 1.0
    weighted = counts * idf
    # Epsilon keeps the division safe when nothing matched the vocabulary.
    return weighted / (np.linalg.norm(weighted) + 1e-8)
class ConjunctionReservoir:
    """
    Document retriever with sentence-level conjunction scoring.

    A query is matched against every indexed sentence: the fraction of query
    tokens that fuzzily match sentence tokens (substring containment in either
    direction) must reach ``conjunction_threshold`` for the sentence to score.
    Scores are damped for sentences that previous queries already surfaced
    (``sentence_coverage``), favoring novel material.  A chunk's score is the
    max over its sentences; when no sentence clears the threshold, a plain
    TF-IDF cosine over whole chunks is used as fallback.
    """
    def __init__(
        self,
        conjunction_threshold: float = 0.5,
        coverage_decay: float = 0.04,
        hebbian_lr: float = 0.01,
        max_vocab: int = 2000,
    ) -> None:
        """
        Args:
            conjunction_threshold: Minimum fraction of query tokens a sentence
                must (fuzzily) contain to be scored at all.
            coverage_decay: Per-query exponential decay of the accumulated
                sentence-coverage vector.
            hebbian_lr: Stored but not read by any code in this module.
                NOTE(review): confirm intended use or remove.
            max_vocab: Maximum vocabulary size for the TF-IDF encoder.
        """
        self.conjunction_threshold = conjunction_threshold
        self.coverage_decay = coverage_decay
        self.hebbian_lr = hebbian_lr
        self.max_vocab = max_vocab
        # All of the following are populated by build_index().
        self.vocab: Optional[Dict[str, int]] = None          # token -> feature column
        self.idf: Optional[np.ndarray] = None                # smoothed IDF weights
        self.chunk_texts: List[str] = []
        self.all_sentences: List[str] = []
        self.sentence_to_chunk: List[int] = []               # sentence idx -> owning chunk idx
        self.sent_feats: Optional[np.ndarray] = None         # (n_sentences, |vocab|), L2-normalized
        self.chunk_feats: Optional[np.ndarray] = None        # (n_chunks, |vocab|), L2-normalized
        self.sentence_coverage: Optional[np.ndarray] = None  # accumulated exposure per sentence
        self.n_queries: int = 0
        self.index_time: float = 0.0
    def build_index(
        self,
        text_or_chunks: Union[str, List[str]],
        verbose: bool = True,
    ) -> "ConjunctionReservoir":
        """
        Build the sentence- and chunk-level TF-IDF index.

        Args:
            text_or_chunks: Raw document text (chunked via chunk_document) or
                an already-chunked list of strings.
            verbose: Accepted for API compatibility; currently unused.
                NOTE(review): no logging is wired to this flag.

        Returns:
            self, for fluent chaining.

        Raises:
            ValueError: If no chunks or no sentences can be extracted.
        """
        t0 = time.perf_counter()
        if isinstance(text_or_chunks, str):
            self.chunk_texts = chunk_document(text_or_chunks)
        else:
            self.chunk_texts = list(text_or_chunks)
        if not self.chunk_texts:
            raise ValueError("No chunks found.")
        self.all_sentences = []
        self.sentence_to_chunk = []
        for chunk_idx, chunk in enumerate(self.chunk_texts):
            for s in split_sentences(chunk):
                self.all_sentences.append(s)
                self.sentence_to_chunk.append(chunk_idx)
        if not self.all_sentences:
            raise ValueError("No sentences extracted.")
        # Vocab is built over sentences AND whole chunks so both feature
        # matrices share one token space.
        self.vocab = build_vocab(
            self.all_sentences + self.chunk_texts, max_vocab=self.max_vocab
        )
        self.idf = tfidf_weights(self.all_sentences, self.vocab)
        self.sent_feats = np.array(
            [encode_text(s, self.vocab, self.idf) for s in self.all_sentences]
        )
        self.chunk_feats = np.array(
            [encode_text(c, self.vocab, self.idf) for c in self.chunk_texts]
        )
        self.sentence_coverage = np.zeros(len(self.all_sentences))
        self.index_time = time.perf_counter() - t0
        return self
    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        update_coverage: bool = True,
    ) -> List[Tuple[str, float]]:
        """
        Return the *top_k* best-matching chunks for *query*.

        Args:
            query: Free-text query.
            top_k: Maximum number of chunks to return.
            update_coverage: When True, fold this query's sentence scores into
                the coverage state so repeated queries favor fresh material.

        Returns:
            List of ``(chunk_text, score)`` pairs, highest score first.
            Scores from the TF-IDF fallback path may be zero.

        Raises:
            RuntimeError: If build_index() has not been called.
        """
        if self.vocab is None:
            raise RuntimeError("Call build_index() before retrieve().")
        # Query tokens use a >=3-letter regex (stricter than the vocab's >=2).
        q_tokens = set(re.findall(r"\b[a-zA-Z]{3,}\b", query.lower()))
        q_feat = encode_text(query, self.vocab, self.idf)
        sent_scores = np.zeros(len(self.all_sentences))
        for s_idx, sentence in enumerate(self.all_sentences):
            s_tokens = set(re.findall(r"\b[a-zA-Z]{3,}\b", sentence.lower()))
            # Fuzzy conjunction: a query token matches if it contains or is
            # contained by any sentence token.
            matches = sum(
                1 for qt in q_tokens
                if any(qt in st or st in qt for st in s_tokens)
            )
            token_coverage = matches / len(q_tokens) if q_tokens else 0.0
            if token_coverage < self.conjunction_threshold:
                continue
            tfidf_sim = float(self.sent_feats[s_idx] @ q_feat)
            conj_weight = token_coverage ** 2   # quadratic reward for full conjunctions
            vollan_w = 1.0 / (1.0 + self.sentence_coverage[s_idx])  # novelty damping
            sent_scores[s_idx] = tfidf_sim * conj_weight * vollan_w
        # Chunk score = max over the chunk's sentence scores.
        chunk_scores = np.zeros(len(self.chunk_texts))
        for score, chunk_idx in zip(sent_scores, self.sentence_to_chunk):
            if score > chunk_scores[chunk_idx]:
                chunk_scores[chunk_idx] = score
        if chunk_scores.max() == 0.0:
            # No sentence cleared the threshold: fall back to plain TF-IDF
            # cosine between the query and whole chunks.
            chunk_scores = self.chunk_feats @ q_feat
        top_idx = chunk_scores.argsort()[-top_k:][::-1]
        results = [(self.chunk_texts[i], float(chunk_scores[i])) for i in top_idx]
        if update_coverage and sent_scores.max() > 0.0:
            # Decay old coverage, then add this query's normalized scores.
            norm = sent_scores / (sent_scores.max() + 1e-8)
            self.sentence_coverage = (
                self.sentence_coverage * (1.0 - self.coverage_decay) + norm
            )
            # NOTE(review): n_queries only counts coverage-updating queries
            # (update_coverage=True AND at least one sentence matched) —
            # confirm this undercount is intended before relying on summary().
            self.n_queries += 1
        return results
    def summary(self) -> Dict:
        """Return index statistics (sizes, parameters, timings) as a dict."""
        return {
            "n_chunks": len(self.chunk_texts),
            "n_sentences": len(self.all_sentences),
            "avg_sentences_per_chunk": round(
                len(self.all_sentences) / max(1, len(self.chunk_texts)), 2
            ),
            "vocab_size": len(self.vocab) if self.vocab else 0,
            "conjunction_threshold": self.conjunction_threshold,
            "coverage_decay": self.coverage_decay,
            "n_queries": self.n_queries,
            "index_time_ms": round(self.index_time * 1000, 1),
        }
    def coverage_profile(self) -> Dict:
        """Return the 10 most-covered sentences plus mean coverage; empty dict
        if the index has not been built."""
        if self.sentence_coverage is None:
            return {}
        top_idx = self.sentence_coverage.argsort()[-10:][::-1]
        return {
            "most_covered": [
                (self.all_sentences[i], round(float(self.sentence_coverage[i]), 4))
                for i in top_idx
                if self.sentence_coverage[i] > 0
            ],
            "mean_coverage": round(float(self.sentence_coverage.mean()), 6),
            "n_queries": self.n_queries,
        }