recon / src /agents /critic.py
MukulRay's picture
fix: trust summary resilience, Unicode in reason strings, full GitHub README
4f24399
import logging
import json
import re
from datetime import datetime
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage, HumanMessage
from src.state import ResearchState, Paper, Verdict
from src.reliability import score_papers, ReliabilityScore
# Load variables from a local .env file into the process environment at import
# time, before get_llm() builds the client (presumably GROQ_API_KEY for
# ChatGroq — confirm).
load_dotenv()
logger = logging.getLogger(__name__)
# Evaluated once at import time; a process that keeps running past New Year
# keeps the old value.
CURRENT_YEAR = datetime.now().year
# Lazily-created shared LLM client; populated by get_llm() on its first call.
_llm: ChatGroq | None = None
def get_llm() -> ChatGroq:
    """Return the module-level ChatGroq client, constructing it on first use.

    The instance is cached in the global ``_llm`` so all callers in this module
    share one client (model ``llama-3.3-70b-versatile``, temperature 0.1).
    """
    global _llm
    if _llm is not None:
        return _llm
    _llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.1)
    return _llm
# ---------------------------------------------------------------------------
# Prompts
# ---------------------------------------------------------------------------
# System prompt for _check_contradiction: the model must answer with a single
# JSON object of the form {"contradicts": true/false, "reason": "..."}.
CONTRADICTION_SYSTEM = """You are evaluating whether two ML research papers contradict each other.
Given two paper abstracts, determine if paper B explicitly refutes, supersedes, or contradicts the main claims of paper A.
Output ONLY a JSON object:
{"contradicts": true/false, "reason": "one sentence explanation"}
Be strict — only mark contradicts=true if there is a clear, explicit disagreement on a specific technical claim.
Do not mark contradicts=true just because papers propose different methods."""
# System prompt for _rewrite_questions: asks for two shorter, keyword-oriented
# search queries returned as a bare JSON array of strings.
REWRITE_SYSTEM = "\n".join([
    "You are rewriting research sub-questions to improve search results.",
    "Given the original sub-questions and a rewrite strategy, output ONLY a JSON array of 2 new questions.",
    "Make questions shorter and more keyword-focused for academic search.",
    'Format: ["question 1", "question 2"]',
    "No preamble.",
])
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _mean_age_months(papers: list[Paper], *, current_year: int | None = None) -> float:
    """Return the mean age of ``papers`` in months, at whole-year resolution.

    Papers without a positive ``year`` are ignored. Each age is clamped at zero,
    so a paper dated later than the reference year (e.g. a forthcoming
    proceedings volume) cannot pull the mean below zero.

    Args:
        papers: Papers to average over.
        current_year: Reference year. Defaults to the year at call time rather
            than the import-time ``CURRENT_YEAR`` constant, which goes stale in
            a process that keeps running past New Year.

    Returns:
        Mean age in months, or ``9999.0`` when no paper has a usable year, so
        an age-based staleness threshold treats unknown ages as very old.
    """
    reference_year = datetime.now().year if current_year is None else current_year
    ages = [max(0, reference_year - p.year) * 12 for p in papers if p.year and p.year > 0]
    return sum(ages) / len(ages) if ages else 9999.0
def _check_contradiction(paper_a: Paper, paper_b: Paper) -> tuple[bool, str]:
    """Ask the LLM whether ``paper_b`` contradicts or supersedes ``paper_a``.

    Only the first 400 characters of each abstract are sent.

    Returns:
        ``(contradicts, reason)``. Returns ``(False, "Missing abstract")`` when
        either abstract is empty, and ``(False, "")`` when the LLM call raises
        or its reply contains no parseable JSON object.
    """
    if not paper_a.abstract or not paper_b.abstract:
        return False, "Missing abstract"

    prompt = "\n".join([
        f"Paper A ({paper_a.year}): {paper_a.title}",
        f"Abstract: {paper_a.abstract[:400]}",
        f"Paper B ({paper_b.year}): {paper_b.title}",
        f"Abstract: {paper_b.abstract[:400]}",
        "Does paper B explicitly contradict or supersede paper A?",
    ])

    try:
        response = get_llm().invoke([
            SystemMessage(content=CONTRADICTION_SYSTEM),
            HumanMessage(content=prompt),
        ])
        raw = response.content.strip()
        # Greedy match: take everything from the first "{" to the last "}" so
        # any preamble/epilogue the model adds around the JSON is ignored.
        match = re.search(r"\{.*\}", raw, re.DOTALL)
        if match:
            data = json.loads(match.group())
            flag = data.get("contradicts", False)
            if isinstance(flag, str):
                # Models sometimes quote booleans; bool("false") would be True.
                flag = flag.strip().lower() == "true"
            # Normalise a missing/null/non-string reason to a plain string.
            return bool(flag), str(data.get("reason") or "")
    except Exception as e:
        logger.warning(f"Contradiction check failed: {e}")
    return False, ""
def _detect_contradictions(
    papers: list[Paper],
    *,
    top_n: int = 4,
    min_year_gap: int = 2,
) -> list[tuple[str, str, str]]:
    """Ask the LLM about contradictions among the first ``top_n`` papers.

    Only pairs whose publication years differ by at least ``min_year_gap`` are
    checked, always asking whether the newer paper contradicts the older one.
    Papers without a positive ``year`` are skipped: they cannot be ordered, and
    treating a missing year as 0 would let them pass every gap test and always
    be labelled "older", producing spurious comparisons.

    Args:
        papers: Candidate papers; only the first ``top_n`` are considered.
        top_n: Size of the window taken from the front of ``papers``.
        min_year_gap: Minimum difference in publication year for a pair to be checked.

    Returns:
        ``(older_title, newer_title, reason)`` for each contradiction reported.
    """
    # Slice before filtering so the candidate window is exactly the first
    # ``top_n`` entries (presumably the highest-ranked ones — confirm upstream).
    dated = [p for p in papers[:top_n] if p.year and p.year > 0]
    contradictions: list[tuple[str, str, str]] = []
    for i, pa in enumerate(dated):
        for pb in dated[i + 1:]:
            if abs(pa.year - pb.year) < min_year_gap:
                continue
            older, newer = (pa, pb) if pa.year < pb.year else (pb, pa)
            contradicts, reason = _check_contradiction(older, newer)
            if contradicts:
                contradictions.append((older.title, newer.title, reason))
                logger.info(f"Contradiction: '{older.title[:40]}' vs '{newer.title[:40]}'")
    return contradictions
def _rewrite_questions(sub_questions: list[str], strategy: str) -> list[str]:
"""Rewrite sub-questions using LLM based on strategy."""
if not sub_questions:
return []
strategy_instructions = {
"broaden": "Broaden the questions to retrieve more papers. Use more general terms.",
"recent": "Rewrite to focus on very recent work (2023-2025). Add 'recent', 'latest', '2024' keywords.",
"probe_contradiction": "Rewrite to explore the specific disagreement. Focus on the contested claim.",
}
instruction = strategy_instructions.get(strategy, strategy_instructions["broaden"])
questions_text = "\n".join(f"{i+1}. {q}" for i, q in enumerate(sub_questions[:2]))
prompt = f"""Strategy: {instruction}
Original questions:
{questions_text}
Output 2 improved search queries as a JSON array."""
try:
response = get_llm().invoke([
SystemMessage(content=REWRITE_SYSTEM),
HumanMessage(content=prompt),
])
raw = response.content.strip()
match = re.search(r"\[.*\]", raw, re.DOTALL)
if match:
questions = json.loads(match.group())
return [q for q in questions if isinstance(q, str)][:2]
except Exception as e:
logger.warning(f"Question rewrite failed: {e}")
return sub_questions[:2]
# ---------------------------------------------------------------------------
# Critic node
# ---------------------------------------------------------------------------
def _format_contradictions(contradictions: list[tuple[str, str, str]]) -> str:
    """Render ``(older_title, newer_title, reason)`` triples as one '; '-joined string."""
    return "; ".join(f"'{older}' vs '{newer}': {reason}" for older, newer, reason in contradictions)


def _serialize_reliability(reliability_scores: dict) -> dict:
    """Copy each score object's attribute dict so the returned state does not alias live objects."""
    return {pid: dict(vars(rs)) for pid, rs in reliability_scores.items()}


def _retry_update(
    verdict: Verdict,
    notes: str,
    sub_questions: list[str],
    strategy: str,
    retry_count: int,
    reliability_scores: dict,
) -> dict:
    """Build the state update for a verdict that increments retry_count and rewrites sub-questions."""
    rewritten = _rewrite_questions(sub_questions, strategy)
    return {
        "critic_verdict": verdict,
        "critic_notes": notes,
        "rewritten_questions": rewritten,
        "retry_count": retry_count + 1,
        "calibration_bin": verdict,
        "paper_reliability_scores": _serialize_reliability(reliability_scores),
    }


def critic_node(state: ResearchState) -> ResearchState:
    """Grade the retrieved evidence and decide whether another retrieval round is needed.

    Reads: retrieved_papers, retry_count, sub_questions, original_query
    Writes: critic_verdict, critic_notes, rewritten_questions,
            retry_count, calibration_bin, paper_reliability_scores

    Verdicts, in order of precedence:
      * FORCED_PASS  - retry_count >= 2; accept the available evidence.
      * INSUFFICIENT - fewer than 3 papers, or fewer than 3 with hybrid_score >= 0.40
                       (questions rewritten with the "broaden" strategy).
      * CONTRADICTED - the LLM reported a contradiction among the top papers; wins
                       over STALE when both fire ("probe_contradiction" strategy).
      * STALE        - mean reliability < 0.50, or mean age > 24 months when
                       reliability scoring produced no scores ("recent" strategy).
      * PASS         - none of the above.
    INSUFFICIENT, CONTRADICTED and STALE increment retry_count.
    """
    papers = state.get("retrieved_papers") or []
    # `or 0` also covers an explicit None in the state, which `>=` cannot compare.
    retry_count = state.get("retry_count") or 0
    sub_questions = state.get("sub_questions") or []

    # FORCED PASS — max retries reached
    if retry_count >= 2:
        logger.info("Critic: max retries reached, forcing PASS")
        return {
            "critic_verdict": Verdict.FORCED_PASS,
            "critic_notes": "Max retries reached. Passing with available evidence.",
            "rewritten_questions": [],
            "retry_count": retry_count,
            "calibration_bin": Verdict.FORCED_PASS,
            "paper_reliability_scores": {},
        }

    # INSUFFICIENT — not enough papers
    if len(papers) < 3:
        logger.info(f"Critic: insufficient papers ({len(papers)})")
        return _retry_update(
            Verdict.INSUFFICIENT,
            f"Only {len(papers)} papers retrieved. Need at least 3.",
            sub_questions, "broaden", retry_count, {},
        )

    # INSUFFICIENT — scores too low
    high_score_papers = [p for p in papers if p.hybrid_score >= 0.40]
    if len(high_score_papers) < 3:
        logger.info("Critic: insufficient high-score papers")
        return _retry_update(
            Verdict.INSUFFICIENT,
            "Fewer than 3 papers with hybrid_score >= 0.40.",
            sub_questions, "broaden", retry_count, {},
        )

    # --- Phase 2.4: edge reliability scores for every paper ---
    original_query = state.get("original_query") or ""
    try:
        # `or {}` guards against a scorer returning None instead of an empty dict.
        reliability_scores = score_papers(papers, query=original_query, use_llm=True) or {}
    except Exception as e:
        logger.warning(f"score_papers() failed entirely: {e} — trust summary will be skipped")
        reliability_scores = {}
    else:
        if not reliability_scores:
            logger.warning("score_papers() returned empty dict — falling back to no reliability scores")

    # --- STALE and CONTRADICTED are both always evaluated ---
    mean_age = _mean_age_months(papers)
    if reliability_scores:
        # v2: STALE = low mean reliability across the retrieved papers.
        mean_reliability = sum(rs.score for rs in reliability_scores.values()) / len(reliability_scores)
        is_stale = mean_reliability < 0.50
        stale_summary = f"mean reliability {mean_reliability:.2f}, mean age {mean_age:.0f} months"
        stale_notes = (
            f"Evidence is stale (mean reliability {mean_reliability:.2f} < 0.50 threshold, "
            f"mean age {mean_age:.0f} months)"
        )
    else:
        # v1 fallback when the scorer produced nothing: mean paper age.
        is_stale = mean_age > 24
        stale_summary = f"mean age {mean_age:.0f} months"
        stale_notes = f"Evidence is stale (mean age {mean_age:.0f} months > 24 month threshold)"

    contradictions = _detect_contradictions(papers)

    # --- Combine signals: CONTRADICTED wins when both fire ---
    if contradictions:
        details = _format_contradictions(contradictions)
        if is_stale:
            notes = f"CONTRADICTED (also stale, {stale_summary}). Contradictions found: {details}"
        else:
            notes = f"Contradictions found: {details}"
        return _retry_update(
            Verdict.CONTRADICTED, notes, sub_questions, "probe_contradiction",
            retry_count, reliability_scores,
        )

    if is_stale:
        return _retry_update(
            Verdict.STALE, stale_notes, sub_questions, "recent",
            retry_count, reliability_scores,
        )

    # PASS — all checks clear
    return {
        "critic_verdict": Verdict.PASS,
        "critic_notes": f"Evidence passes all checks (mean age {mean_age:.0f} months, {len(papers)} papers, no contradictions detected)",
        "retry_count": retry_count,
        "rewritten_questions": [],
        "calibration_bin": Verdict.PASS,
        "paper_reliability_scores": _serialize_reliability(reliability_scores),
    }