# saivivek6's picture
# Deploy vivek app to Space: Vue UI (bundled dist), streaming widgets, JSON schema renderer
# c97e8a9
"""Utility helpers: math, response enforcement, and valence heuristics.
This module provides:
- sigmoid + uncertainty helpers
- lightweight valence detection (heuristic)
- format override + explore trigger detection
- response post-processing to enforce selected format
"""
import json
import re
import math
import numpy as np
def sigmoid(x: float) -> float:
    """Return the logistic function ``1 / (1 + e**-x)`` evaluated at *x*.

    The argument is clamped to ``[-500, 500]`` before exponentiation so that
    ``math.exp`` cannot overflow for extreme inputs.
    """
    clamped = float(np.clip(x, -500, 500))
    denominator = 1.0 + math.exp(-clamped)
    return 1.0 / denominator
def mean_uncertainty(sigma_inv: np.ndarray) -> float:
    """Return the mean reciprocal of the diagonal entries of *sigma_inv*.

    Each diagonal entry is floored at ``1e-8`` before being inverted, so zero
    or negative entries cannot cause a division by zero.
    """
    diagonal = np.diag(sigma_inv)
    floored = np.clip(diagonal, 1e-8, None)
    reciprocals = 1.0 / floored
    return float(np.mean(reciprocals))
# --- Valence signals ---
# Words/phrases counted as positive feedback on the previous response.
_POS = re.compile(
    r"\b(thank|thanks|great|perfect|awesome|love|helpful|useful|exactly|makes sense|clear|brilliant|nice|good|yes|correct|right)\b",
    re.I,
)
# Words/phrases counted as negative feedback.  Apostrophes are written as
# ``.`` so that both the ASCII (') and typographic (’) forms match.
_NEG = re.compile(
    r"\b(wrong|incorrect|confused|confusing|not what|still don.t|don.t understand|that.s not|try again|again\b|useless|unhelpful|bad|nope\b|nah\b|wtf\b|huh\??|noooo+)\b",
    re.I,
)
# Phrases suggesting the user is repeating or rephrasing an earlier request.
_REPHRASE = re.compile(r"\b(what I mean|let me rephrase|I said|as I mentioned|again|once more)\b", re.I)


def fast_valence(message: str, prev_response: str) -> dict:
    """Return heuristic valence and a human-readable reason string.

    Args:
        message: The user's latest message.  ``None`` is treated as an empty
            string, matching the other detectors in this module.
        prev_response: The response the message reacts to.  When it is empty
            or ``None`` the message is treated as the first turn.

    Returns:
        A dict with ``"pos"`` and ``"neg"`` scores, each clipped to
        ``[0, 1]``, and a ``"reason"`` string listing the signals that fired
        (``"neutral"`` when none did, ``"first message"`` on the first turn).
    """
    if not prev_response:
        return {"pos": 0.5, "neg": 0.1, "reason": "first message"}

    # Guard against a missing message instead of failing inside re / len().
    message = message or ""

    pos_hits = len(_POS.findall(message))
    neg_hits = len(_NEG.findall(message))
    rephr = 1 if _REPHRASE.search(message) else 0

    # A reply that is tiny both in absolute terms and relative to the previous
    # response adds to the negative score.
    len_ratio = len(message) / max(len(prev_response), 1)
    brevity_neg = 0.3 if (len_ratio < 0.08 and len(message) < 15) else 0.0

    pos = float(np.clip(0.3 + 0.3 * pos_hits - 0.1 * neg_hits, 0.0, 1.0))
    neg = float(np.clip(0.1 + 0.3 * neg_hits + 0.2 * rephr + brevity_neg, 0.0, 1.0))

    reasons = []
    if pos_hits:
        reasons.append(f"{pos_hits} positive signal(s)")
    if neg_hits:
        reasons.append(f"{neg_hits} negative signal(s)")
    if rephr:
        reasons.append("rephrase")
    if brevity_neg:
        reasons.append("very short reply")
    return {"pos": pos, "neg": neg, "reason": ", ".join(reasons) or "neutral"}
# --- Explicit format override detection (user mentions what they want) ---
_OVERRIDE_PATTERNS = [
("structured_bullets", re.compile(r"\b(bullets?|bullet points?|list it|in bullets)\b", re.I)),
("step_by_step", re.compile(r"\b(step by step|steps?|walk me through|procedure)\b", re.I)),
("concise_direct", re.compile(r"\b(concise|short|tl;dr|tldr|in 1-3 sentences)\b", re.I)),
("narrative_prose", re.compile(r"\b(paragraph|narrative|in prose|explain like a story)\b", re.I)),
("socratic_questions", re.compile(r"\b(ask me|ask questions|clarifying questions?)\b", re.I)),
("comparison_table", re.compile(r"\b(table|comparison table|pros and cons|compare|vs\.?|versus)\b", re.I)),
("visualization", re.compile(r"\b(chart|plot|graph|visuali[sz]e|visualization|bar chart|pie chart)\b", re.I)),
]
def detect_format_override(message: str, available: list[str]) -> str | None:
m = (message or "").strip().lower()
for strat, pat in _OVERRIDE_PATTERNS:
if strat in available and pat.search(m):
return strat
return None
# --- Explore triggers (force trying a different format) ---
# Phrases signalling that the user wants the answer presented differently.
_EXPLORE_TRIG = re.compile(
    r"\b(try again|different|another way|not this|still the same|nope\b|nah\b|noooo+)\b",
    re.I,
)


def detect_explore_trigger(message: str) -> bool:
    """Return True when *message* contains an explore-trigger phrase.

    A missing (``None``) or empty message never triggers.
    """
    text = message if message else ""
    return _EXPLORE_TRIG.search(text.strip()) is not None
def negative_strength(ev: dict) -> float:
    """Return the ``"neg"`` entry of *ev* clipped to the range [0, 1].

    An empty or missing *ev*, or one without a ``"neg"`` key, yields 0.0.
    """
    raw = ev.get("neg", 0.0) if ev else 0.0
    return float(np.clip(raw, 0.0, 1.0))
# --- Response enforcement ---
def enforce_response(strategy: str, text: str) -> str:
"""Post-process model output to strongly encourage the chosen format."""
t = (text or "").strip()
if strategy == "structured_bullets":
parts = re.split(r"[\n]+", t)
if len(parts) <= 2 and len(t) > 160:
parts = re.split(r"(?<=[.!])\s+", t)
cleaned = []
for p in parts:
p = p.strip().lstrip("-•* ").strip()
if not p:
continue
if "?" in p:
continue
cleaned.append(p)
cleaned = cleaned[:5]
if len(cleaned) < 3:
cleaned = cleaned or [t.replace("?", "").strip()]
return "\n".join(["- " + c for c in cleaned[:5]])
if strategy == "step_by_step":
lines = [ln.strip() for ln in re.split(r"[\n]+", t) if ln.strip()]
items = []
for ln in lines:
ln = re.sub(r"^([\-*•]|\d+[.)])\s*", "", ln).strip()
if ln:
items.append(ln)
items = items[:6] or [t]
return "\n".join([f"{i+1}. {it}" for i, it in enumerate(items[:6])])
if strategy == "concise_direct":
sents = re.split(r"(?<=[.!?])\s+", t)
return " ".join(sents[:3]).strip()
if strategy == "socratic_questions":
qs = re.findall(r"[^\n?]*\?", t)
if qs:
qs = [q.strip() for q in qs if q.strip()][:2]
return "Got it.\n" + "\n".join(["- " + q for q in qs])
return "Got it.\n- What outcome do you want?\n- Any constraints or example input/output?"
if strategy == "comparison_table":
try:
obj = json.loads(t)
if isinstance(obj, dict) and isinstance(obj.get("columns"), list) and isinstance(obj.get("rows"), list):
cols = [str(c).strip() for c in obj.get("columns", [])]
rows = obj.get("rows", [])
if cols and isinstance(rows, list):
header = "| " + " | ".join(cols) + " |"
sep = "| " + " | ".join(["---"] * len(cols)) + " |"
data_lines = []
for r in rows[:8]:
if not isinstance(r, list):
continue
vals = [str(v).strip() for v in r[: len(cols)]]
if len(vals) < len(cols):
vals += [""] * (len(cols) - len(vals))
data_lines.append("| " + " | ".join(vals) + " |")
return "\n".join([header, sep] + data_lines) if data_lines else "\n".join([header, sep])
except Exception:
pass
lines = [ln.rstrip() for ln in t.splitlines() if ln.strip()]
table_lines = [ln for ln in lines if "|" in ln]
if len(table_lines) >= 2:
if not any(re.match(r"^\|?\s*[:-]{2,}", ln) for ln in table_lines):
header = table_lines[0]
cols = [c.strip() for c in header.strip("|").split("|")]
sep = "|" + "|".join(["---"] * len(cols)) + "|"
return "\n".join([header, sep] + table_lines[1:6])
return "\n".join(table_lines[:8])
fallback = {
"columns": ["Option", "Pros", "Cons", "Best for"],
"rows": [
["A", "", "", ""],
["B", "", "", ""],
],
}
header = "| Option | Pros | Cons | Best for |"
sep = "| --- | --- | --- | --- |"
rows = ["| A | | | |", "| B | | | |"]
return "\n".join([header, sep] + rows)
if strategy == "visualization":
try:
obj = json.loads(t)
if (
isinstance(obj, dict)
and isinstance(obj.get("type"), str)
and isinstance(obj.get("labels"), list)
and isinstance(obj.get("values"), list)
):
labels = [str(x) for x in obj.get("labels", [])]
values = obj.get("values", [])
pairs: list[tuple[str, float]] = []
for i, lab in enumerate(labels):
try:
val = float(values[i]) if i < len(values) else 0.0
except Exception:
val = 0.0
pairs.append((lab, val))
max_val = max([abs(v) for _, v in pairs], default=1.0) or 1.0
lines = []
for lab, val in pairs[:10]:
width = int(round((abs(val) / max_val) * 24))
bar = "#" * max(1, width)
lines.append(f"{lab:<12} | {bar} {val:g}")
body = "\n".join(lines) if lines else "A | #### 1\nB | #### 1"
return f"```text\n{body}\n```"
except Exception:
pass
m = re.search(r"```[\s\S]*?```", t)
if m:
return m.group(0)
fallback = {
"type": "bar",
"title": "Visualization",
"labels": ["A", "B"],
"values": [1, 1],
"x_label": "Category",
"y_label": "Value",
}
return "```text\nA | #### 1\nB | #### 1\n```"
return t