bbkdevops's picture
download
raw
8.75 kB
"""Omega++ reasoning helpers around the base language model.
Upgraded: self-consistency voting, IDF-weighted lexical retrieval, and chain-of-thought (CoT) verification.
"""
from __future__ import annotations

import math
import re
from collections import deque
from dataclasses import dataclass
def _tokens(text: str) -> set[str]:
    """Return the set of lowercased word tokens found in *text*.

    A token is a maximal run of word characters or characters in the Thai
    range used by the pattern below. Runs of length one are discarded; the
    length check is applied to the raw match, before lowercasing.
    """
    found: set[str] = set()
    for match in re.finditer(r"[\wก-๙]+", text, flags=re.UNICODE):
        word = match.group(0)
        if len(word) > 1:
            found.add(word.lower())
    return found
def _tfidf_score(query_tokens: set[str], doc_tokens: set[str], idf: dict[str, float]) -> float:
    """Return the IDF-weighted cosine similarity of two token sets.

    Each set is treated as a binary vector whose non-zero components carry the
    token's weight from ``idf``; tokens missing from ``idf`` get weight 1.0.

    Args:
        query_tokens: Tokens of the query.
        doc_tokens: Tokens of the candidate document.
        idf: Mapping from token to its inverse-document-frequency weight.

    Returns:
        A similarity in ``[0.0, 1.0]``. It is 0.0 when either set is empty or
        the sets share no token, and 1.0 when the sets are identical.
    """
    if not query_tokens or not doc_tokens:
        return 0.0
    shared = query_tokens & doc_tokens
    if not shared:
        return 0.0

    def weight_sq(token: str) -> float:
        return idf.get(token, 1.0) ** 2

    # The numerator must be the dot product of the two weighted vectors, i.e.
    # the sum of squared weights over the shared tokens.  Summing the plain
    # weights instead cancels the rarity weighting against the norms below:
    # a match on a rare token would then score no higher than a match on a
    # common one, and an exact match would score below 1.
    dot = sum(weight_sq(t) for t in shared)
    norm = math.sqrt(
        sum(weight_sq(t) for t in query_tokens)
        * sum(weight_sq(t) for t in doc_tokens)
    )
    # The floor guards against division by zero when every weight is 0; the
    # clamp absorbs floating-point rounding just above 1.
    return min(dot / max(norm, 1e-9), 1.0)
class RetrievalMemory:
    """Bounded first-in-first-out store of text snippets with lexical search.

    Every entry keeps its key, its content and the token set derived from
    both. Per-token document frequencies are maintained incrementally so that
    inverse document frequencies can be computed at query time.
    """

    def __init__(self, max_size: int = 10_000) -> None:
        """Create an empty memory that holds at most *max_size* entries.

        A non-positive *max_size* yields a memory that stores nothing.
        """
        # A deque makes evicting the oldest entry O(1); list.pop(0) is O(n).
        self._items: deque[tuple[str, str, set[str]]] = deque()
        self._max_size = max_size
        self._doc_freq: dict[str, int] = {}

    def _evict_oldest(self) -> None:
        """Remove the oldest entry and release its document-frequency counts."""
        _, _, old_tokens = self._items.popleft()
        for token in old_tokens:
            remaining = self._doc_freq.get(token, 0) - 1
            if remaining > 0:
                self._doc_freq[token] = remaining
            else:
                # Drop exhausted tokens so the table cannot grow without bound
                # as old entries rotate out.
                self._doc_freq.pop(token, None)

    def add(self, key: str, content: str) -> None:
        """Store *content* under *key*, evicting the oldest entries when full."""
        if self._max_size <= 0:
            return
        tokens = _tokens(f"{key} {content}")
        while len(self._items) >= self._max_size:
            self._evict_oldest()
        self._items.append((key, content, tokens))
        for token in tokens:
            self._doc_freq[token] = self._doc_freq.get(token, 0) + 1

    def _idf(self, token: str) -> float:
        """Return the smoothed IDF of *token*: ``log((n + 1) / (df + 1)) + 1``."""
        n = len(self._items)
        df = self._doc_freq.get(token, 0)
        return math.log((n + 1) / (df + 1)) + 1.0

    def search(self, query: str, top_k: int = 4) -> list[str]:
        """Return the contents of up to *top_k* entries that best match *query*.

        Entries are ranked by ``_tfidf_score`` in descending order; entries
        with a score of zero are never returned. A non-positive *top_k* or a
        query without usable tokens yields an empty list.
        """
        if top_k <= 0:
            # A negative slice bound would otherwise return the wrong items.
            return []
        query_tokens = _tokens(query)
        if not query_tokens:
            return []
        # Weights are computed for the query tokens only; any other document
        # token falls back to the default weight inside _tfidf_score.
        idf = {token: self._idf(token) for token in query_tokens}
        scored: list[tuple[float, str]] = []
        for _, content, item_tokens in self._items:
            score = _tfidf_score(query_tokens, item_tokens, idf)
            if score > 0:
                scored.append((score, content))
        # The sort is stable, so ties keep insertion (oldest-first) order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [content for _, content in scored[:top_k]]

    def add_bulk(self, items: list[tuple[str, str]]) -> None:
        """Add every ``(key, content)`` pair of *items* in order."""
        for key, content in items:
            self.add(key, content)
@dataclass(frozen=True)
class VerificationResult:
    """Immutable outcome of scoring one candidate answer."""

    # Heuristic quality score of the answer.
    score: float
    # Whether the score passed the acceptance threshold used by the scorer.
    accepted: bool
    # Short labels naming the weaknesses that were detected, if any.
    reasons: tuple[str, ...]
class OmegaPlusLogic:
    """Prompt building, answer verification and self-consistency voting.

    The class wraps a retrieval memory and offers three services: building a
    retrieval-augmented prompt, scoring a model answer with lexical
    heuristics, and picking the best of several sampled answers.
    """

    # Final answer inside a "<think>...</think><answer>...</answer>" reply.
    _ANSWER_RE = re.compile(r"<answer>([\s\S]*?)</answer>", re.IGNORECASE)
    # Presence of a non-empty reasoning trace.
    _THINK_RE = re.compile(r"<think>[\s\S]+</think>", re.IGNORECASE)
    # Word splitter used for answer-to-answer agreement.
    _WORD_RE = re.compile(r"\w+")
    # Phrases (Thai and English) that signal an unsure answer.
    _UNCERTAIN_MARKERS = ("ไม่แน่ใจ", "i don't know", "not sure", "ไม่ทราบ")

    def __init__(
        self,
        memory: RetrievalMemory | None = None,
        min_score: float = 0.35,
        enable_cot: bool = True,
    ) -> None:
        """Configure the helper.

        Args:
            memory: Retrieval memory to query; a fresh one is created when None.
            min_score: Minimum verification score for an answer to be accepted.
            enable_cot: Whether prompts request a ``<think>``/``<answer>`` reply
                and whether a missing reasoning trace is reported.
        """
        # Compare against None explicitly: a supplied memory must be kept even
        # if it happens to evaluate as falsy (for example an empty container).
        self.memory = memory if memory is not None else RetrievalMemory()
        self.min_score = min_score
        self.enable_cot = enable_cot

    # ─── Prompt Building ─────────────────────────────────────────────────────
    def enhance_prompt(
        self,
        question: str,
        system_prompt: str,
        top_k: int = 4,
        *,
        retrieval_query: str | None = None,
    ) -> str:
        """Build a prompt that embeds retrieved memories next to the question.

        Args:
            question: Text placed inside the ``<user>`` section.
            system_prompt: Text placed at the start of the ``<system>`` section.
            top_k: Maximum number of memories to retrieve.
            retrieval_query: Text used to search the memory; defaults to
                *question* when None.

        Returns:
            The prompt string, ending with an open ``<assistant>`` tag.
        """
        query = question if retrieval_query is None else retrieval_query
        memories = self.memory.search(query, top_k=top_k)
        context = "\n".join(f"- {m}" for m in memories)
        if not context:
            # Thai fallback: no matching information was found; rely on
            # training knowledge and state the uncertainty.
            context = "- ไม่พบข้อมูลที่ตรงกัน ใช้ความรู้จากการ train และระบุความไม่แน่ใจ"
        if self.enable_cot:
            # Thai instruction: reason inside <think>, then answer in <answer>.
            cot_instruction = (
                "\nวิธีตอบ: คิดวิเคราะห์ใน <think>...</think> แล้วให้คำตอบใน <answer>...</answer>"
            )
        else:
            cot_instruction = ""
        # NOTE(review): question and memories are interpolated without any
        # escaping, so text containing these tags can alter the prompt layout.
        return (
            f"<system>{system_prompt}{cot_instruction}\n"
            f"ข้อมูลที่เกี่ยวข้อง:\n{context}</system>\n"
            f"<user>{question}</user>\n<assistant>"
        )

    def build_self_consistency_prompt(
        self, question: str, system_prompt: str, attempt: int = 0
    ) -> str:
        """Build one of several prompt variants for self-consistency sampling.

        A reasoning prefix, chosen by *attempt* (cycling through the list), is
        prepended to the question. The Thai list is used when more than 10% of
        the question's characters fall in the Thai Unicode block.
        """
        thai_prefixes = [
            "คิดอย่างรอบคอบแล้วตอบ:",
            "วิเคราะห์จากมุมมองต่างๆ:",
            "ใช้เหตุผลเชิงตรรกะ:",
            "ตรวจสอบความถูกต้องก่อนตอบ:",
        ]
        english_prefixes = [
            "Think carefully and answer:",
            "Analyze from multiple angles:",
            "Use logical reasoning:",
            "Verify your reasoning before answering:",
        ]
        # The two bounds delimit the Thai Unicode block (U+0E00 to U+0E7F).
        thai_chars = sum(1 for c in question if "฀" <= c <= "๿")
        is_thai = thai_chars / max(len(question), 1) > 0.1
        prefixes = thai_prefixes if is_thai else english_prefixes
        prefix = prefixes[attempt % len(prefixes)]
        # Retrieve with the raw question: the prefix words are not part of
        # what the user asked and would otherwise skew the memory search
        # differently on every attempt.
        return self.enhance_prompt(
            f"{prefix} {question}", system_prompt, retrieval_query=question
        )

    # ─── Verification ─────────────────────────────────────────────────────────
    @classmethod
    def _extract_answer(cls, text: str) -> str:
        """Return the stripped ``<answer>`` payload, or all of *text* stripped."""
        match = cls._ANSWER_RE.search(text)
        return match.group(1).strip() if match else text.strip()

    def verify_answer(self, question: str, answer: str) -> VerificationResult:
        """Score *answer* against *question* with lexical heuristics.

        The score combines token overlap with the question, answer length
        (saturating at 200 characters), a bonus for a ``<think>`` trace and a
        penalty for uncertain wording, clamped to ``[0, 1]``.

        Returns:
            A ``VerificationResult`` whose ``accepted`` flag is True when the
            score reaches ``self.min_score``. Empty answers, including an
            empty ``<answer>`` tag, score 0.0 and are rejected.
        """
        answer = answer.strip()
        if not answer:
            return VerificationResult(0.0, False, ("empty answer",))

        clean_answer = self._extract_answer(answer)
        if not clean_answer:
            # An empty <answer></answer> must not earn the base and reasoning
            # points: there is nothing to verify.
            return VerificationResult(0.0, False, ("empty answer",))

        q_tokens = _tokens(question)
        a_tokens = _tokens(clean_answer)
        overlap = len(q_tokens & a_tokens) / max(len(q_tokens), 1)
        length_score = min(len(clean_answer) / 200.0, 1.0)
        has_reasoning = bool(self._THINK_RE.search(answer))
        reasoning_bonus = 0.15 if has_reasoning else 0.0
        lowered = answer.lower()
        is_uncertain = any(marker in lowered for marker in self._UNCERTAIN_MARKERS)
        uncertainty_penalty = 0.2 if is_uncertain else 0.0

        raw = (
            0.35 * overlap + 0.40 * length_score + reasoning_bonus + 0.10
            - uncertainty_penalty
        )
        score = max(0.0, min(1.0, raw))

        reasons: list[str] = []
        if overlap == 0:
            reasons.append("low lexical grounding")
        if len(clean_answer) < 12:
            reasons.append("too short")
        if is_uncertain:
            reasons.append("uncertain language")
        if not has_reasoning and self.enable_cot:
            reasons.append("no reasoning trace")
        return VerificationResult(score, score >= self.min_score, tuple(reasons))

    # ─── Self-Consistency ─────────────────────────────────────────────────────
    @staticmethod
    def _jaccard(a: set[str], b: set[str]) -> float:
        """Return the Jaccard similarity of two token sets (0.0 if either is empty)."""
        if not a or not b:
            return 0.0
        return len(a & b) / len(a | b)

    def vote_best_answer(
        self,
        question: str,
        candidates: list[str],
    ) -> tuple[str, VerificationResult]:
        """Pick the candidate that best combines peer agreement and quality.

        Each candidate's final score is ``0.6 * agreement + 0.4 * quality``,
        where agreement is its mean Jaccard similarity to the other candidates
        and quality is its ``verify_answer`` score. The first candidate wins
        ties.

        Returns:
            The winning candidate (unmodified) and its verification result, or
            ``("", <rejected result>)`` when *candidates* is empty.
        """
        if not candidates:
            return "", VerificationResult(0.0, False, ("no candidates",))

        # Tokenize and verify each candidate once instead of once per pair.
        token_sets = [
            set(self._WORD_RE.findall(self._extract_answer(c).lower()))
            for c in candidates
        ]
        verdicts = [self.verify_answer(question, c) for c in candidates]
        peer_count = max(len(candidates) - 1, 1)

        final_scores: list[float] = []
        for i, (tokens, verdict) in enumerate(zip(token_sets, verdicts)):
            agreement = sum(
                self._jaccard(tokens, other)
                for j, other in enumerate(token_sets)
                if j != i
            ) / peer_count
            final_scores.append(0.6 * agreement + 0.4 * verdict.score)

        best_idx = max(range(len(final_scores)), key=final_scores.__getitem__)
        return candidates[best_idx], verdicts[best_idx]

    def rerank(self, question: str, candidates: list[str]) -> tuple[str, VerificationResult]:
        """Alias of :meth:`vote_best_answer`."""
        return self.vote_best_answer(question, candidates)

Xet Storage Details

Size:
8.75 kB
·
Xet hash:
c342af2495d2ad8d01d7fdf3363339e281aadb0ae6bfaab549d5fe74c576df2d

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.