"""Utility helpers: math, response enforcement, and valence heuristics. This module provides: - sigmoid + uncertainty helpers - lightweight valence detection (heuristic) - format override + explore trigger detection - response post-processing to enforce selected format """ import json import re import math import numpy as np def sigmoid(x: float) -> float: x = float(np.clip(x, -500, 500)) return 1.0 / (1.0 + math.exp(-x)) def mean_uncertainty(sigma_inv: np.ndarray) -> float: prec = np.diag(sigma_inv) return float(np.mean(1.0 / np.clip(prec, 1e-8, None))) # --- Valence signals --- _POS = re.compile( r"\b(thank|thanks|great|perfect|awesome|love|helpful|useful|exactly|makes sense|clear|brilliant|nice|good|yes|correct|right)\b", re.I, ) _NEG = re.compile( r"\b(wrong|incorrect|confused|confusing|not what|still don.t|don't understand|that.s not|try again|again\b|useless|unhelpful|bad|nope\b|nah\b|wtf\b|huh\??|noooo+)\b", re.I, ) _REPHRASE = re.compile(r"\b(what I mean|let me rephrase|I said|as I mentioned|again|once more)\b", re.I) def fast_valence(message: str, prev_response: str) -> dict: """Return heuristic valence and a human-readable reason string.""" if not prev_response: return {"pos": 0.5, "neg": 0.1, "reason": "first message"} pos_hits = len(_POS.findall(message)) neg_hits = len(_NEG.findall(message)) rephr = 1 if _REPHRASE.search(message) else 0 len_ratio = len(message) / max(len(prev_response), 1) brevity_neg = 0.3 if (len_ratio < 0.08 and len(message) < 15) else 0.0 pos = float(np.clip(0.3 + 0.3 * pos_hits - 0.1 * neg_hits, 0.0, 1.0)) neg = float(np.clip(0.1 + 0.3 * neg_hits + 0.2 * rephr + brevity_neg, 0.0, 1.0)) reasons = [] if pos_hits: reasons.append(f"{pos_hits} positive signal(s)") if neg_hits: reasons.append(f"{neg_hits} negative signal(s)") if rephr: reasons.append("rephrase") if brevity_neg: reasons.append("very short reply") return {"pos": pos, "neg": neg, "reason": ", ".join(reasons) or "neutral"} # --- Explicit format override detection (user mentions what they want) --- _OVERRIDE_PATTERNS = [ ("structured_bullets", re.compile(r"\b(bullets?|bullet points?|list it|in bullets)\b", re.I)), ("step_by_step", re.compile(r"\b(step by step|steps?|walk me through|procedure)\b", re.I)), ("concise_direct", re.compile(r"\b(concise|short|tl;dr|tldr|in 1-3 sentences)\b", re.I)), ("narrative_prose", re.compile(r"\b(paragraph|narrative|in prose|explain like a story)\b", re.I)), ("socratic_questions", re.compile(r"\b(ask me|ask questions|clarifying questions?)\b", re.I)), ("comparison_table", re.compile(r"\b(table|comparison table|pros and cons|compare|vs\.?|versus)\b", re.I)), ("visualization", re.compile(r"\b(chart|plot|graph|visuali[sz]e|visualization|bar chart|pie chart)\b", re.I)), ] def detect_format_override(message: str, available: list[str]) -> str | None: m = (message or "").strip().lower() for strat, pat in _OVERRIDE_PATTERNS: if strat in available and pat.search(m): return strat return None # --- Explore triggers (force trying a different format) --- _EXPLORE_TRIG = re.compile(r"\b(try again|different|another way|not this|still the same|nope\b|nah\b|noooo+)\b", re.I) def detect_explore_trigger(message: str) -> bool: return bool(_EXPLORE_TRIG.search((message or "").strip())) def negative_strength(ev: dict) -> float: """Map heuristic valence to a [0,1] strength scalar.""" if not ev: return 0.0 return float(np.clip(ev.get("neg", 0.0), 0.0, 1.0)) # --- Response enforcement --- def enforce_response(strategy: str, text: str) -> str: """Post-process model output to strongly encourage the chosen format.""" t = (text or "").strip() if strategy == "structured_bullets": parts = re.split(r"[\n]+", t) if len(parts) <= 2 and len(t) > 160: parts = re.split(r"(?<=[.!])\s+", t) cleaned = [] for p in parts: p = p.strip().lstrip("-•* ").strip() if not p: continue if "?" in p: continue cleaned.append(p) cleaned = cleaned[:5] if len(cleaned) < 3: cleaned = cleaned or [t.replace("?", "").strip()] return "\n".join(["- " + c for c in cleaned[:5]]) if strategy == "step_by_step": lines = [ln.strip() for ln in re.split(r"[\n]+", t) if ln.strip()] items = [] for ln in lines: ln = re.sub(r"^([\-*•]|\d+[.)])\s*", "", ln).strip() if ln: items.append(ln) items = items[:6] or [t] return "\n".join([f"{i+1}. {it}" for i, it in enumerate(items[:6])]) if strategy == "concise_direct": sents = re.split(r"(?<=[.!?])\s+", t) return " ".join(sents[:3]).strip() if strategy == "socratic_questions": qs = re.findall(r"[^\n?]*\?", t) if qs: qs = [q.strip() for q in qs if q.strip()][:2] return "Got it.\n" + "\n".join(["- " + q for q in qs]) return "Got it.\n- What outcome do you want?\n- Any constraints or example input/output?" if strategy == "comparison_table": try: obj = json.loads(t) if isinstance(obj, dict) and isinstance(obj.get("columns"), list) and isinstance(obj.get("rows"), list): cols = [str(c).strip() for c in obj.get("columns", [])] rows = obj.get("rows", []) if cols and isinstance(rows, list): header = "| " + " | ".join(cols) + " |" sep = "| " + " | ".join(["---"] * len(cols)) + " |" data_lines = [] for r in rows[:8]: if not isinstance(r, list): continue vals = [str(v).strip() for v in r[: len(cols)]] if len(vals) < len(cols): vals += [""] * (len(cols) - len(vals)) data_lines.append("| " + " | ".join(vals) + " |") return "\n".join([header, sep] + data_lines) if data_lines else "\n".join([header, sep]) except Exception: pass lines = [ln.rstrip() for ln in t.splitlines() if ln.strip()] table_lines = [ln for ln in lines if "|" in ln] if len(table_lines) >= 2: if not any(re.match(r"^\|?\s*[:-]{2,}", ln) for ln in table_lines): header = table_lines[0] cols = [c.strip() for c in header.strip("|").split("|")] sep = "|" + "|".join(["---"] * len(cols)) + "|" return "\n".join([header, sep] + table_lines[1:6]) return "\n".join(table_lines[:8]) fallback = { "columns": ["Option", "Pros", "Cons", "Best for"], "rows": [ ["A", "", "", ""], ["B", "", "", ""], ], } header = "| Option | Pros | Cons | Best for |" sep = "| --- | --- | --- | --- |" rows = ["| A | | | |", "| B | | | |"] return "\n".join([header, sep] + rows) if strategy == "visualization": try: obj = json.loads(t) if ( isinstance(obj, dict) and isinstance(obj.get("type"), str) and isinstance(obj.get("labels"), list) and isinstance(obj.get("values"), list) ): labels = [str(x) for x in obj.get("labels", [])] values = obj.get("values", []) pairs: list[tuple[str, float]] = [] for i, lab in enumerate(labels): try: val = float(values[i]) if i < len(values) else 0.0 except Exception: val = 0.0 pairs.append((lab, val)) max_val = max([abs(v) for _, v in pairs], default=1.0) or 1.0 lines = [] for lab, val in pairs[:10]: width = int(round((abs(val) / max_val) * 24)) bar = "#" * max(1, width) lines.append(f"{lab:<12} | {bar} {val:g}") body = "\n".join(lines) if lines else "A | #### 1\nB | #### 1" return f"```text\n{body}\n```" except Exception: pass m = re.search(r"```[\s\S]*?```", t) if m: return m.group(0) fallback = { "type": "bar", "title": "Visualization", "labels": ["A", "B"], "values": [1, 1], "x_label": "Category", "y_label": "Value", } return "```text\nA | #### 1\nB | #### 1\n```" return t