Buckets:
| """Omega++ reasoning helpers around the base language model. | |
| Upgraded: self-consistency voting, lexical TF-IDF retrieval, CoT verification. | |
| """ | |
from __future__ import annotations

import math
import re
from collections import deque
from dataclasses import dataclass
| def _tokens(text: str) -> set[str]: | |
| return {t.lower() for t in re.findall(r"[\wก-๙]+", text, flags=re.UNICODE) if len(t) > 1} | |
| def _tfidf_score(query_tokens: set[str], doc_tokens: set[str], idf: dict[str, float]) -> float: | |
| """TF-IDF weighted Jaccard — คะแนนดีกว่า simple overlap""" | |
| if not query_tokens or not doc_tokens: | |
| return 0.0 | |
| overlap = query_tokens & doc_tokens | |
| if not overlap: | |
| return 0.0 | |
| score = sum(idf.get(t, 1.0) for t in overlap) | |
| denom = math.sqrt( | |
| sum(idf.get(t, 1.0) ** 2 for t in query_tokens) * | |
| sum(idf.get(t, 1.0) ** 2 for t in doc_tokens) | |
| ) | |
| return score / max(denom, 1e-9) | |
class RetrievalMemory:
    """Bounded in-memory store of (key, content) snippets with TF-IDF search.

    Items are kept in insertion order. Once ``max_size`` items are stored,
    each new :meth:`add` first evicts the oldest item. Document frequencies
    are maintained incrementally, so IDF weights reflect only the items
    currently held.
    """

    def __init__(self, max_size: int = 10_000) -> None:
        """Create an empty memory that holds at most *max_size* items.

        Raises:
            ValueError: if *max_size* is less than 1.
        """
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        # Each entry is (key, content, token set of f"{key} {content}").
        # A deque gives O(1) eviction of the oldest entry.
        self._items: deque[tuple[str, str, set[str]]] = deque()
        self._max_size = max_size
        # token -> number of stored items whose token set contains it.
        self._doc_freq: dict[str, int] = {}

    def add(self, key: str, content: str) -> None:
        """Store a snippet, evicting the oldest one first if the memory is full.

        Both *key* and *content* are tokenized for matching; only *content*
        is returned by :meth:`search`. Repeated keys are stored as separate
        items (no de-duplication).
        """
        tokens = _tokens(f"{key} {content}")
        if len(self._items) >= self._max_size:
            _, _, old_tokens = self._items.popleft()
            for t in old_tokens:
                remaining = self._doc_freq.get(t, 0) - 1
                if remaining > 0:
                    self._doc_freq[t] = remaining
                else:
                    # Drop zero counts so the table doesn't grow without bound
                    # as vocabulary churns through eviction.
                    self._doc_freq.pop(t, None)
        self._items.append((key, content, tokens))
        for t in tokens:
            self._doc_freq[t] = self._doc_freq.get(t, 0) + 1

    def _idf(self, token: str) -> float:
        """Return the smoothed IDF ``log((n + 1) / (df + 1)) + 1`` for *token*.

        ``n`` is the number of stored items and ``df`` the number of them
        containing *token*; an unseen token gets ``log(n + 1) + 1``.
        """
        n = len(self._items)
        df = self._doc_freq.get(token, 0)
        return math.log((n + 1) / (df + 1)) + 1.0

    def search(self, query: str, top_k: int = 4) -> list[str]:
        """Return up to *top_k* stored contents ranked by similarity to *query*.

        Only items sharing at least one token with the query are returned.
        Items with equal scores keep insertion order (oldest first), because
        the sort is stable. Returns ``[]`` when *top_k* is not positive or the
        query has no tokens.
        """
        if top_k <= 0:
            return []
        q = _tokens(query)
        if not q:
            return []
        # IDF is computed only for query tokens; _tfidf_score weights every
        # other document token as 1.0 in the document norm.
        idf = {t: self._idf(t) for t in q}
        scored: list[tuple[float, str]] = []
        for _, content, item_tokens in self._items:
            score = _tfidf_score(q, item_tokens, idf)
            if score > 0:
                scored.append((score, content))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [content for _, content in scored[:top_k]]

    def add_bulk(self, items: list[tuple[str, str]]) -> None:
        """Add each ``(key, content)`` pair in order via :meth:`add`."""
        for key, content in items:
            self.add(key, content)
@dataclass(frozen=True)
class VerificationResult:
    """Immutable outcome of a heuristic answer check.

    The ``@dataclass`` decorator generates the positional constructor
    ``VerificationResult(score, accepted, reasons)``. Without it, the class
    accepts no arguments and every construction raises ``TypeError``.
    """

    # Heuristic quality score (OmegaPlusLogic.verify_answer clamps it to [0, 1]).
    score: float
    # Whether the score met the caller's acceptance threshold.
    accepted: bool
    # Human-readable notes on weak points, e.g. "too short".
    reasons: tuple[str, ...]
class OmegaPlusLogic:
    """Reasoning helpers: retrieval-augmented prompts, answer scoring, voting.

    Prompts embed snippets retrieved from a RetrievalMemory. Answers are scored
    with a lexical/length heuristic (:meth:`verify_answer`). Several sampled
    answers can be reduced to one by agreement-weighted voting
    (:meth:`vote_best_answer`).
    """

    # Body of the first <answer>...</answer> block, matched case-insensitively.
    _ANSWER_RE = re.compile(r"<answer>([\s\S]*?)</answer>", re.IGNORECASE)
    # A non-empty <think>...</think> block.
    _THINK_RE = re.compile(r"<think>[\s\S]+</think>", re.IGNORECASE)
    # Word tokens used for the candidate-agreement Jaccard similarity.
    _WORD_RE = re.compile(r"\w+")
    # Hedging phrases (Thai: "not sure", "don't know") searched in the
    # lowercased raw answer.
    _UNCERTAIN_PHRASES = ("ไม่แน่ใจ", "i don't know", "not sure", "ไม่ทราบ")
    # Reasoning prefixes for self-consistency sampling. The Thai and English
    # tuples are parallel translations, cycled by attempt number.
    _THAI_PREFIXES = (
        "คิดอย่างรอบคอบแล้วตอบ:",
        "วิเคราะห์จากมุมมองต่างๆ:",
        "ใช้เหตุผลเชิงตรรกะ:",
        "ตรวจสอบความถูกต้องก่อนตอบ:",
    )
    _EN_PREFIXES = (
        "Think carefully and answer:",
        "Analyze from multiple angles:",
        "Use logical reasoning:",
        "Verify your reasoning before answering:",
    )

    def __init__(
        self,
        memory: RetrievalMemory | None = None,
        min_score: float = 0.35,
        enable_cot: bool = True,
    ) -> None:
        """Configure the helper.

        Args:
            memory: Store searched for prompt context; a new
                ``RetrievalMemory()`` is created when None.
            min_score: Minimum verify_answer score for an answer to be accepted.
            enable_cot: Whether prompts ask for <think>/<answer> tags and
                whether a missing reasoning trace is reported by verify_answer.
        """
        self.memory = memory if memory is not None else RetrievalMemory()
        self.min_score = min_score
        self.enable_cot = enable_cot

    # ─── Helpers ──────────────────────────────────────────────────────────────
    @classmethod
    def _extract_answer(cls, text: str) -> str:
        """Return the stripped body of the first <answer> block, else the stripped text."""
        match = cls._ANSWER_RE.search(text)
        return match.group(1).strip() if match else text.strip()

    @staticmethod
    def _is_thai(text: str) -> bool:
        """Return True when more than 10% of *text*'s characters are Thai.

        The Thai Unicode block is U+0E00..U+0E7F. The bounds are written as
        escapes because both endpoints are unassigned, invisible code points
        that are easily lost when pasted as literal characters.
        """
        thai_count = sum(1 for ch in text if "\u0e00" <= ch <= "\u0e7f")
        return thai_count / max(len(text), 1) > 0.1

    @staticmethod
    def _jaccard(a: set[str], b: set[str]) -> float:
        """Jaccard similarity of two word sets; 0.0 if either is empty."""
        if not a or not b:
            return 0.0
        return len(a & b) / len(a | b)

    # ─── Prompt Building ─────────────────────────────────────────────────────
    def enhance_prompt(
        self,
        question: str,
        system_prompt: str,
        top_k: int = 4,
        *,
        retrieval_query: str | None = None,
    ) -> str:
        """Build a tagged chat prompt with retrieved context.

        Args:
            question: Text placed inside the ``<user>`` tag.
            system_prompt: Text placed at the start of the ``<system>`` tag.
            top_k: Maximum number of memory snippets to include.
            retrieval_query: Text used to search memory. Defaults to
                *question*. This lets callers decorate the question (e.g.
                with a reasoning prefix) without that decoration influencing
                retrieval.

        Returns:
            The prompt, ending with an open ``<assistant>`` tag.
        """
        query = question if retrieval_query is None else retrieval_query
        memories = self.memory.search(query, top_k=top_k)
        context = "\n".join(f"- {m}" for m in memories)
        if not context:
            # Nothing retrieved: tell the model to rely on its training
            # knowledge and to state its uncertainty.
            context = "- ไม่พบข้อมูลที่ตรงกัน ใช้ความรู้จากการ train และระบุความไม่แน่ใจ"
        # When enabled, ask the model to reason in <think> and answer in <answer>.
        cot_instruction = (
            "\nวิธีตอบ: คิดวิเคราะห์ใน <think>...</think> แล้วให้คำตอบใน <answer>...</answer>"
            if self.enable_cot
            else ""
        )
        return (
            f"<system>{system_prompt}{cot_instruction}\n"
            f"ข้อมูลที่เกี่ยวข้อง:\n{context}</system>\n"
            f"<user>{question}</user>\n<assistant>"
        )

    def build_self_consistency_prompt(
        self, question: str, system_prompt: str, attempt: int = 0
    ) -> str:
        """Build one prompt variant for self-consistency sampling.

        A reasoning prefix is prepended to *question*, cycling through four
        variants by ``attempt``. Thai prefixes are used when more than 10% of
        the question's characters are Thai; otherwise English ones are used.
        Memory is searched with the undecorated question, so the prefix text
        does not affect which snippets are retrieved.
        """
        prefixes = self._THAI_PREFIXES if self._is_thai(question) else self._EN_PREFIXES
        prefix = prefixes[attempt % len(prefixes)]
        return self.enhance_prompt(
            f"{prefix} {question}", system_prompt, retrieval_query=question
        )

    # ─── Verification ─────────────────────────────────────────────────────────
    def verify_answer(self, question: str, answer: str) -> VerificationResult:
        """Heuristically score *answer* for *question* and decide acceptance.

        The score is
        ``clamp(0.35*overlap + 0.40*length + bonus + 0.10 - penalty, 0, 1)``,
        where:

        - overlap is the fraction of question tokens present in the extracted
          answer;
        - length saturates at 200 characters of extracted answer;
        - bonus is 0.15 when a non-empty <think> block is present;
        - penalty is 0.2 when a hedging phrase appears anywhere in the raw
          answer.

        Returns:
            VerificationResult with the score, ``score >= self.min_score`` as
            ``accepted``, and reasons describing weak points.
        """
        answer = answer.strip()
        if not answer:
            return VerificationResult(0.0, False, ("empty answer",))

        # Score only the <answer> body when the model used CoT tags.
        clean_answer = self._extract_answer(answer)
        q_tokens = _tokens(question)
        a_tokens = _tokens(clean_answer)
        overlap = len(q_tokens & a_tokens) / max(len(q_tokens), 1)
        length_score = min(len(clean_answer) / 200.0, 1.0)
        has_reasoning = bool(self._THINK_RE.search(answer))
        reasoning_bonus = 0.15 if has_reasoning else 0.0
        lowered = answer.lower()
        uncertainty_penalty = (
            0.2 if any(p in lowered for p in self._UNCERTAIN_PHRASES) else 0.0
        )
        score = max(0.0, min(1.0,
            0.35 * overlap + 0.40 * length_score + reasoning_bonus + 0.10 - uncertainty_penalty
        ))

        reasons: list[str] = []
        if overlap == 0:
            reasons.append("low lexical grounding")
        if len(clean_answer) < 12:
            reasons.append("too short")
        if uncertainty_penalty:
            reasons.append("uncertain language")
        if not has_reasoning and self.enable_cot:
            reasons.append("no reasoning trace")
        return VerificationResult(score, score >= self.min_score, tuple(reasons))

    # ─── Self-Consistency ─────────────────────────────────────────────────────
    def vote_best_answer(
        self,
        question: str,
        candidates: list[str],
    ) -> tuple[str, VerificationResult]:
        """Pick the best candidate by agreement with the others plus quality.

        A candidate's agreement is its mean word-set Jaccard similarity to
        every other candidate's extracted answer. Its final score is
        ``0.6 * agreement + 0.4 * verify_answer(...).score``. Ties go to the
        earliest candidate.

        Returns:
            The winning candidate (raw, tags included) and its
            VerificationResult. For an empty list, returns ``""`` with a
            rejected result and the reason ``"no candidates"``.
        """
        if not candidates:
            return "", VerificationResult(0.0, False, ("no candidates",))

        # Tokenize each extracted answer once rather than once per pair.
        word_sets = [
            set(self._WORD_RE.findall(self._extract_answer(c).lower())) for c in candidates
        ]
        count = len(candidates)
        agreement: list[float] = []
        for i, words in enumerate(word_sets):
            total = sum(
                self._jaccard(words, other)
                for j, other in enumerate(word_sets)
                if j != i
            )
            agreement.append(total / max(count - 1, 1))

        # Verify each candidate once and reuse the winner's result.
        verdicts = [self.verify_answer(question, c) for c in candidates]
        final_scores = [0.6 * a + 0.4 * v.score for a, v in zip(agreement, verdicts)]
        best_idx = max(range(count), key=final_scores.__getitem__)
        return candidates[best_idx], verdicts[best_idx]

    def rerank(self, question: str, candidates: list[str]) -> tuple[str, VerificationResult]:
        """Alias for :meth:`vote_best_answer`."""
        return self.vote_best_answer(question, candidates)
Xet Storage Details
- Size:
- 8.75 kB
- Xet hash:
- c342af2495d2ad8d01d7fdf3363339e281aadb0ae6bfaab549d5fe74c576df2d
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.