# fades-api / utils.py
# NOTE: the lines below were Hugging Face Space page residue
# ("Update utils.py", commit e09b7e5 verified) — kept as a comment so the
# module remains valid Python.
import json
import re
from typing import Any, Dict, Optional, List
# Canonical fallacy labels used by the sanitizers below.
# Prefer the project's prompts.py definition; if it is missing (e.g. this
# module is run standalone), fall back to a hard-coded copy so validation
# still works. "none" is a sentinel, not a real fallacy type.
try:
    from prompts import ALLOWED_LABELS  # type: ignore
except Exception:  # import-time fallback, intentionally broad
    ALLOWED_LABELS = [
        "none", "faulty generalization", "false causality", "circular reasoning",
        "ad populum", "ad hominem", "fallacy of logic", "appeal to emotion",
        "false dilemma", "equivocation", "fallacy of extension",
        "fallacy of relevance", "fallacy of credibility", "miscellaneous", "intentional"
    ]
# ----------------------------
# Robust JSON extraction
# ----------------------------
def stop_at_complete_json(text: str) -> Optional[str]:
    """Return the first brace-balanced ``{...}`` substring of *text*, or None.

    Braces that occur inside quoted JSON strings (including escaped quotes)
    are ignored while balancing, so ``{"a": "}"}`` is handled correctly.
    Returns None when no ``{`` exists or the object never closes.
    """
    first = text.find("{")
    if first < 0:
        return None

    depth = 0
    inside_string = False
    escaped = False
    for pos in range(first, len(text)):
        c = text[pos]
        if inside_string:
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                inside_string = False
        elif c == '"':
            inside_string = True
        elif c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            # depth can only return to zero on a closing brace
            if depth == 0:
                return text[first : pos + 1]
    return None
def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
    """Parse and return the first JSON object found in *s*, or None.

    Cuts at the first brace-balanced object when possible, then parses the
    outermost ``{...}`` span. Any parse failure yields None.
    """
    region = stop_at_complete_json(s) or s
    lo = region.find("{")
    hi = region.rfind("}")
    if lo < 0 or hi < 0 or hi <= lo:
        return None
    try:
        parsed = json.loads(region[lo : hi + 1].strip())
    except Exception:
        return None
    return parsed
# ----------------------------
# Extra robustness: remove stray unquoted fields (e.g., `confidence: 0.75`)
# that sometimes appear outside JSON strings due to generation glitches.
# ----------------------------
def _remove_unquoted_confidence_field(json_text: str) -> str:
"""
Removes an unquoted trailing field like `confidence: 0.75` that appears
outside strings in otherwise-valid JSON output. This is a targeted fix
for common LLM glitches and intentionally conservative (only triggers
when we are NOT inside a quoted string).
"""
out_chars: List[str] = []
i = 0
in_str = False
esc = False
def _pop_trailing_ws_and_optional_comma():
# remove trailing whitespace
while out_chars and out_chars[-1].isspace():
out_chars.pop()
# remove trailing comma (and whitespace before it)
if out_chars and out_chars[-1] == ",":
out_chars.pop()
while out_chars and out_chars[-1].isspace():
out_chars.pop()
while i < len(json_text):
ch = json_text[i]
if in_str:
out_chars.append(ch)
if esc:
esc = False
elif ch == "\\": # escape
esc = True
elif ch == '"':
in_str = False
i += 1
continue
if ch == '"':
in_str = True
out_chars.append(ch)
i += 1
continue
# Detect an unquoted `confidence: <number>` outside strings.
# Only remove if followed by a number and then a delimiter (`,` or `}`).
if json_text.startswith("confidence", i):
j = i + len("confidence")
while j < len(json_text) and json_text[j].isspace():
j += 1
if j < len(json_text) and json_text[j] == ":":
j += 1
while j < len(json_text) and json_text[j].isspace():
j += 1
# parse a simple number
if j < len(json_text) and json_text[j] in "+-":
j += 1
has_digit = False
while j < len(json_text) and json_text[j].isdigit():
has_digit = True
j += 1
if j < len(json_text) and json_text[j] == ".":
j += 1
while j < len(json_text) and json_text[j].isdigit():
has_digit = True
j += 1
if has_digit:
k = j
while k < len(json_text) and json_text[k].isspace():
k += 1
if k < len(json_text) and json_text[k] in {",", "}"}:
_pop_trailing_ws_and_optional_comma()
i = k # keep delimiter
continue
out_chars.append(ch)
i += 1
return "".join(out_chars)
def extract_json_obj_robust(s: str) -> Optional[Dict[str, Any]]:
    """
    Best-effort extraction of the first JSON object from a model output string.

    Steps:
      1. cut at the first complete ``{...}`` (brace-balanced, string-aware);
      2. strip markdown code fences (```json / ```);
      3. repair a common glitch: an unquoted trailing ``confidence: <num>``;
      4. ``json.loads``, returning None on any failure.
    """
    region = stop_at_complete_json(s) or s
    lo = region.find("{")
    hi = region.rfind("}")
    if lo == -1 or hi == -1 or hi <= lo:
        return None
    candidate = region[lo : hi + 1].strip()
    for fence in ("```json", "```"):
        candidate = candidate.replace(fence, "")
    candidate = _remove_unquoted_confidence_field(candidate.strip())
    try:
        return json.loads(candidate)
    except Exception:
        return None
# ----------------------------
# Post-processing: remove template sentence
# ----------------------------
_TEMPLATE_RE = re.compile(
r"\bthe input contains fallacious reasoning consistent with the predicted type\(s\)\b\.?",
flags=re.IGNORECASE,
)
def strip_template_sentence(text: str) -> str:
if not isinstance(text, str):
return ""
out = _TEMPLATE_RE.sub("", text)
out = out.replace("..", ".").strip()
out = re.sub(r"\s{2,}", " ", out)
out = re.sub(r"^\s*[\-–—:;\.\s]+", "", out).strip()
return out
# ----------------------------
# Output sanitation / validation
# ----------------------------
def _clamp01(x: Any, default: float = 0.5) -> float:
try:
v = float(x)
except Exception:
return default
return 0.0 if v < 0.0 else (1.0 if v > 1.0 else v)
def _is_allowed_label(lbl: Any) -> bool:
    """True when *lbl* is a known fallacy label other than the "none" sentinel."""
    if not isinstance(lbl, str):
        return False
    return lbl != "none" and lbl in ALLOWED_LABELS
def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, Any]:
    """Validate and normalize a raw model "analyze" payload.

    Keeps only entries whose ``type`` is an allowed label, clamps/rounds
    ``confidence`` to 2 decimals, filters ``evidence_quotes`` to non-empty
    strings that actually occur in *input_text* (truncated to 240 chars,
    max 3 kept), and strips the templated filler sentence from rationales
    and the overall explanation. ``has_fallacy`` is forced to False when no
    entry survives filtering.
    """
    raw_items = obj.get("fallacies", [])
    if not isinstance(raw_items, list):
        raw_items = []

    cleaned_items = []
    for item in raw_items:
        if not isinstance(item, dict):
            continue
        label = item.get("type")
        if not _is_allowed_label(label):
            continue

        confidence = float(f"{_clamp01(item.get('confidence', 0.5)):.2f}")

        raw_quotes = item.get("evidence_quotes", [])
        quotes: List[str] = []
        for quote in (raw_quotes if isinstance(raw_quotes, list) else []):
            if not isinstance(quote, str):
                continue
            stripped = quote.strip()
            if not stripped or stripped not in input_text:
                continue
            if len(stripped) > 240:
                prefix = stripped[:240]
                # a prefix of a substring of input_text is itself a substring,
                # so the fallback to the full quote is defensive only
                quotes.append(prefix if prefix in input_text else stripped)
            else:
                quotes.append(stripped)

        cleaned_items.append(
            {
                "type": label,
                "confidence": confidence,
                "evidence_quotes": quotes[:3],
                "rationale": strip_template_sentence(str(item.get("rationale", "")).strip()),
            }
        )

    return {
        "has_fallacy": bool(obj.get("has_fallacy", False)) and bool(cleaned_items),
        "fallacies": cleaned_items,
        "overall_explanation": strip_template_sentence(
            str(obj.get("overall_explanation", "")).strip()
        ),
    }
# ----------------------------
# Replace helpers
# ----------------------------
def occurrence_index(text: str, sub: str, occurrence: int) -> int:
    """Return the index of the `occurrence`-th (0-based) match of *sub*, or -1.

    Matches are non-overlapping: the search resumes past each hit
    (at least one character forward, even for an empty *sub*).
    """
    if occurrence < 0:
        return -1
    idx = -1
    search_from = 0
    for _ in range(occurrence + 1):
        idx = text.find(sub, search_from)
        if idx == -1:
            return -1
        search_from = idx + max(1, len(sub))
    return idx
def replace_nth(text: str, old: str, new: str, occurrence: int) -> Dict[str, Any]:
    """Replace the `occurrence`-th match of *old* with *new* in *text*.

    Returns ``{"ok": False, "error": "quote_not_found"}`` when the match is
    missing; otherwise ``ok: True`` plus the rewritten text and the character
    spans of both the replacement (in the new text) and the original match.
    """
    pos = occurrence_index(text, old, occurrence)
    if pos < 0:
        return {"ok": False, "error": "quote_not_found"}
    rewritten = f"{text[:pos]}{new}{text[pos + len(old):]}"
    return {
        "ok": True,
        "rewritten_text": rewritten,
        "start_char": pos,
        "end_char": pos + len(new),
        "old_start_char": pos,
        "old_end_char": pos + len(old),
    }