research-agent / src /utils.py
abhid1234's picture
relevance filter, parallel reads, bibliography, cache, share/export, model/recency
3039836 verified
Raw
History Blame Contribute Delete
8.95 kB
"""Shared helpers: LLM completion, chunking, formatting, validation."""
from __future__ import annotations
import re
import threading
import time
from typing import TYPE_CHECKING
from . import config
# Serializes read-modify-write updates to shared usage meters; _record_usage
# takes it because it may run concurrently in parallel reader threads.
_meter_lock = threading.Lock()
if TYPE_CHECKING: # avoid importing heavy modules at runtime just for typing
from .tools import Paper
def _record_usage(meter: dict | None, in_tok: int, out_tok: int) -> None:
if meter is None:
return
with _meter_lock: # called from parallel reader threads
meter["llm_calls"] = meter.get("llm_calls", 0) + 1
meter["in_tok"] = meter.get("in_tok", 0) + (in_tok or 0)
meter["out_tok"] = meter.get("out_tok", 0) + (out_tok or 0)
def complete(
    prompt: str,
    system: str | None = None,
    max_tokens: int = 4096,
    meter: dict | None = None,
    model: str | None = None,
) -> str:
    """Run one prompt through the configured LLM provider and return its text.

    When ``config.PROVIDER == "anthropic"`` the Messages API is used; every
    other provider goes through the OpenAI-compatible chat-completions client.
    Reported token usage is accumulated into ``meter`` (if given), and
    ``model`` overrides ``config.MODEL`` for this call only.

    Any exception is retried once after a 1.5 s pause; if the second attempt
    also fails, its exception propagates unchanged.
    """
    client = config.get_llm_client()
    chosen_model = model or config.MODEL

    def _call_once() -> str:
        # One request/response round trip; usage is recorded before returning.
        if config.PROVIDER == "anthropic":
            request = {
                "model": chosen_model,
                "max_tokens": max_tokens,
                "messages": [{"role": "user", "content": prompt}],
            }
            if system:
                request["system"] = system
            reply = client.messages.create(**request)
            usage = getattr(reply, "usage", None)
            _record_usage(
                meter,
                getattr(usage, "input_tokens", 0),
                getattr(usage, "output_tokens", 0),
            )
            text = "".join(part.text for part in reply.content if part.type == "text")
            return text.strip()

        # openai / gemini (OpenAI-compatible)
        chat = [{"role": "system", "content": system}] if system else []
        chat.append({"role": "user", "content": prompt})
        reply = client.chat.completions.create(
            model=chosen_model,
            max_tokens=max_tokens,
            messages=chat,
        )
        usage = getattr(reply, "usage", None)
        _record_usage(
            meter,
            getattr(usage, "prompt_tokens", 0),
            getattr(usage, "completion_tokens", 0),
        )
        return (reply.choices[0].message.content or "").strip()

    try:
        return _call_once()
    except Exception:  # noqa: BLE001 - retry once, then let the error surface
        time.sleep(1.5)
    return _call_once()
# Approx USD per 1M tokens (input, output). Used only for a rough estimate.
_PRICING = {
"gemini-2.5-flash": (0.30, 2.50),
"gemini-2.5-pro": (1.25, 10.0),
"gpt-4o": (2.50, 10.0),
"claude-sonnet": (3.0, 15.0),
"claude-opus": (15.0, 75.0),
}
def cost_report(meter: dict, model: str | None = None) -> dict:
"""Turn an accumulated meter into a {calls, in_tok, out_tok, est_cost_usd} dict."""
model = (model or config.MODEL).lower()
rate = next((v for k, v in _PRICING.items() if k in model), (0.50, 1.50))
in_tok = meter.get("in_tok", 0)
out_tok = meter.get("out_tok", 0)
est = (in_tok / 1_000_000) * rate[0] + (out_tok / 1_000_000) * rate[1]
return {
"llm_calls": meter.get("llm_calls", 0),
"in_tok": in_tok,
"out_tok": out_tok,
"est_cost_usd": round(est, 4),
}
def chunk_text(text: str, size: int = 3000) -> list[str]:
    """Break ``text`` into pieces of roughly ``size`` characters.

    The text is split on blank lines and the resulting paragraphs are packed
    greedily (joined by a blank line) while the piece stays within ``size``,
    so neighbouring paragraphs stay together where they fit. A paragraph longer
    than ``size`` is emitted on its own, cut into consecutive ``size``-char
    slices. Whitespace-only paragraphs are dropped.
    """
    pieces: list[str] = []
    pending = ""
    for raw in re.split(r"\n\s*\n", text):
        block = raw.strip()
        if not block:
            continue
        if len(block) > size:
            # Oversized paragraph: close the pending piece, then hard-split.
            if pending:
                pieces.append(pending)
                pending = ""
            pieces.extend(block[start : start + size] for start in range(0, len(block), size))
            continue
        if pending and len(pending) + 2 + len(block) <= size:
            pending = f"{pending}\n\n{block}"
        else:
            if pending:
                pieces.append(pending)
            pending = block
    if pending:
        pieces.append(pending)
    return pieces
def sanitize_filename(text: str, max_len: int = 80) -> str:
    """Turn an arbitrary topic string into a safe filename stem.

    Characters other than word characters, whitespace and hyphens are removed.
    The result is lowercased, and runs of whitespace/underscores/hyphens collapse
    to a single "-". It is then truncated to ``max_len`` characters and stripped of
    leading/trailing hyphens.

    Returns ``"review"`` whenever nothing usable remains. This check runs AFTER
    stripping hyphens, so input made only of separators (e.g. ``"- -"``) no
    longer yields an empty stem.
    """
    slug = re.sub(r"[^\w\s-]", "", text).strip().lower()
    slug = re.sub(r"[\s_-]+", "-", slug)
    return slug[:max_len].strip("-") or "review"
def format_papers_for_synthesis(papers: list["Paper"]) -> str:
    """Build the compact per-paper digest fed to the synthesis prompt.

    Each paper becomes a ``### Title (Year)`` heading and an ``Authors:`` line.
    The line shows the first three names, adds "et al." when there are more,
    and reads "Unknown" when there are none. Bullet lists follow for whichever
    of claims / methods / results are non-empty. A paper with none of those
    falls back to the first 500 characters of its abstract, if any. Papers are
    separated by a blank line.
    """
    rendered: list[str] = []
    for paper in papers:
        names = paper.authors
        author_line = ", ".join(names[:3])
        if len(names) > 3:
            author_line += " et al."
        lines = [
            f"### {paper.title} ({paper.year})",
            f"Authors: {author_line or 'Unknown'}",
        ]
        sections = (
            ("Claims", paper.claims),
            ("Methods", paper.methods),
            ("Results", paper.results),
        )
        for label, items in sections:
            if items:
                bullets = "\n".join(f" - {item}" for item in items)
                lines.append(f"{label}:\n{bullets}")
        if paper.abstract and not any(items for _, items in sections):
            lines.append(f"Abstract: {paper.abstract[:500]}")
        rendered.append("\n".join(lines))
    return "\n\n".join(rendered)
def paper_url(p: "Paper") -> str:
    """Pick the link a reader should follow for ``p``.

    Preference order:
    1. the arXiv abstract page for a truthy ``p.arxiv_id``;
    2. the arXiv abstract page built from ``p.id`` when ``p.source == "arxiv"``;
    3. ``p.pdf_url``, or ``""`` when that is empty/None.
    """
    abs_base = "https://arxiv.org/abs/"
    if p.arxiv_id:
        arxiv_ref = p.arxiv_id
    elif p.source == "arxiv":
        arxiv_ref = p.id
    else:
        return p.pdf_url or ""
    return f"{abs_base}{arxiv_ref}"
def paper_meta(p: "Paper") -> dict:
    """Summarize ``p`` as a plain dict for the UI bibliography.

    ``authors`` lists the first three names (plus "et al." when there are
    more) or "Unknown" when there are none. ``url`` comes from ``paper_url``.
    The remaining fields are copied straight from the paper.
    """
    shown_authors = ", ".join(p.authors[:3])
    if len(p.authors) > 3:
        shown_authors = f"{shown_authors} et al."
    return dict(
        title=p.title,
        authors=shown_authors or "Unknown",
        year=p.year,
        url=paper_url(p),
        source=p.source,
        citation_count=p.citation_count,
        read_from=p.read_from,
    )
def references_markdown(papers: list["Paper"]) -> str:
    """Render a numbered, linked '## References' section for provenance.

    Each entry reads ``N. [Title](url) — Authors (Year)``. It is followed by
    `` · N citations`` when ``citation_count`` is truthy and by
    `` · abstract only`` when ``read_from == "abstract"``. The title is plain
    text when ``paper_meta`` yields no URL. The section starts with a blank line
    so it can be appended directly to a review. An empty paper list yields ``""``.
    """
    if not papers:
        return ""
    lines = ["", "## References", ""]
    for number, paper in enumerate(papers, 1):
        meta = paper_meta(paper)
        if meta["url"]:
            # Escape brackets so titles like "[Re] ..." don't break the link text.
            label = str(paper.title).replace("[", r"\[").replace("]", r"\]")
            link = f"[{label}]({meta['url']})"
        else:
            link = paper.title
        cites = f" · {paper.citation_count} citations" if paper.citation_count else ""
        abstract_note = " · abstract only" if paper.read_from == "abstract" else ""
        # The " — " separator keeps the link from running into the author list.
        lines.append(
            f"{number}. {link} — {meta['authors']} ({paper.year}){cites}{abstract_note}"
        )
    return "\n".join(lines)
def validate_review(review: str, papers: list["Paper"]) -> dict:
    """Sanity-check a generated review.

    Confirms the review has at least two Markdown section headers (``#`` to
    ``###``) and that each ``[Author, Year]`` citation names a first-author
    surname and a year found in ``papers``. A bracket group holding several
    citations separated by ``;`` (e.g. ``[Smith, 2020; Lee, 2021]``) is checked
    one citation at a time.

    Args:
        review: Markdown text of the generated review.
        papers: The paper set the review should cite; only ``authors`` and
            ``year`` are read.

    Returns:
        ``{"valid": bool, "issues": [str, ...]}``. ``valid`` is True only when
        no issues were found. Issues are de-duplicated, keeping first-occurrence
        order.
    """
    edge_punct = " .,;:()[]\"'"

    def _surname(name: str) -> str:
        # Last whitespace-separated word, lowercased, edge punctuation removed.
        words = name.split()
        return words[-1].strip(edge_punct).lower() if words else ""

    issues: list[str] = []
    headers = re.findall(r"^#{1,3}\s+\S", review, flags=re.MULTILINE)
    if len(headers) < 2:
        issues.append("Review has fewer than 2 section headers.")

    # Known author surnames and paper years; citations naming anything else
    # are reported as possibly not matching the paper set.
    known: set[str] = set()
    known_years: set[str] = set()
    for p in papers:
        if p.year:
            known_years.add(str(p.year))
        for author in p.authors:
            surname = _surname(author)
            if surname:
                known.add(surname)

    # Innermost [...] groups, each possibly holding several ";"-separated
    # "Author, Year" citations.
    citations: list[tuple[str, str]] = []
    for group in re.findall(r"\[([^\[\]]+)\]", review):
        for part in group.split(";"):
            match = re.fullmatch(r"\s*(.+?),\s*(\d{4})\s*", part)
            if match:
                citations.append((match.group(1).strip(), match.group(2)))
    if not citations:
        issues.append("No [Author, Year] citations found.")

    for author_part, year in citations:
        # Keep only the first author ("Smith et al.", "Smith & Lee",
        # "Smith and Lee", "Smith, Lee" -> "Smith"). Then take its last word,
        # matching how `known` is built, so "van der Berg" -> "berg".
        first_author = re.split(
            r",|&|\band\b|\bet\.?\s*al\b", author_part, maxsplit=1, flags=re.IGNORECASE
        )[0]
        token = _surname(first_author)
        bad_author = bool(token) and bool(known) and token not in known
        bad_year = bool(known_years) and year not in known_years
        if bad_author or bad_year:
            issues.append(f"Citation may not match any paper: [{author_part}, {year}]")

    # De-duplicate while preserving first-occurrence order.
    unique_issues = list(dict.fromkeys(issues))
    return {"valid": not unique_issues, "issues": unique_issues}