"""Shared helpers: LLM completion, chunking, formatting, validation.""" from __future__ import annotations import re import threading import time from typing import TYPE_CHECKING from . import config _meter_lock = threading.Lock() if TYPE_CHECKING: # avoid importing heavy modules at runtime just for typing from .tools import Paper def _record_usage(meter: dict | None, in_tok: int, out_tok: int) -> None: if meter is None: return with _meter_lock: # called from parallel reader threads meter["llm_calls"] = meter.get("llm_calls", 0) + 1 meter["in_tok"] = meter.get("in_tok", 0) + (in_tok or 0) meter["out_tok"] = meter.get("out_tok", 0) + (out_tok or 0) def complete( prompt: str, system: str | None = None, max_tokens: int = 4096, meter: dict | None = None, model: str | None = None, ) -> str: """Send a single prompt to the configured provider and return the text. Retries once on transient failure, then re-raises. ``model`` overrides the configured model for this call; ``meter`` accumulates token usage. 
""" client = config.get_llm_client() model = model or config.MODEL last_err: Exception | None = None for attempt in range(2): try: if config.PROVIDER == "anthropic": kwargs = { "model": model, "max_tokens": max_tokens, "messages": [{"role": "user", "content": prompt}], } if system: kwargs["system"] = system resp = client.messages.create(**kwargs) u = getattr(resp, "usage", None) _record_usage(meter, getattr(u, "input_tokens", 0), getattr(u, "output_tokens", 0)) return "".join( block.text for block in resp.content if block.type == "text" ).strip() # openai / gemini (OpenAI-compatible) messages = [] if system: messages.append({"role": "system", "content": system}) messages.append({"role": "user", "content": prompt}) resp = client.chat.completions.create( model=model, max_tokens=max_tokens, messages=messages, ) u = getattr(resp, "usage", None) _record_usage(meter, getattr(u, "prompt_tokens", 0), getattr(u, "completion_tokens", 0)) return (resp.choices[0].message.content or "").strip() except Exception as err: # noqa: BLE001 - retry-once then surface last_err = err if attempt == 0: time.sleep(1.5) continue raise # Unreachable, but keeps type checkers happy. raise RuntimeError(f"LLM completion failed: {last_err}") # Approx USD per 1M tokens (input, output). Used only for a rough estimate. 
_PRICING = {
    # key: (input, output) USD per 1M tokens. Keys are matched by _rate_for().
    "gemini-2.5-flash-lite": (0.10, 0.40),
    "gemini-2.5-flash": (0.30, 2.50),
    "gemini-2.5-pro": (1.25, 10.0),
    "gpt-4o-mini": (0.15, 0.60),
    "gpt-4o": (2.50, 10.0),
    "claude-sonnet": (3.0, 15.0),
    "claude-opus": (15.0, 75.0),
}
# Rate used when no _PRICING key matches the model name.
_DEFAULT_RATE = (0.50, 1.50)


def _rate_for(model: str) -> tuple[float, float]:
    """Return the (input, output) USD-per-1M-token rate for ``model``.

    A key matches when it is a substring of the lowercased model name, or when
    every ``-``-separated part of the key appears as a token of the name. The
    token rule lets versioned IDs such as ``claude-3-5-sonnet-20241022`` match
    ``claude-sonnet``. The longest matching key wins, so ``gpt-4o-mini`` is not
    priced as ``gpt-4o``.
    """
    name = model.lower()
    tokens = set(re.split(r"[^a-z0-9.]+", name))
    matches = [
        key for key in _PRICING
        if key in name or set(key.split("-")) <= tokens
    ]
    best = max(matches, key=len, default=None)
    return _PRICING[best] if best is not None else _DEFAULT_RATE


def cost_report(meter: dict, model: str | None = None) -> dict:
    """Turn an accumulated meter into a {calls, in_tok, out_tok, est_cost_usd} dict.

    ``model`` defaults to ``config.MODEL``. The cost is a rough estimate based
    on the ``_PRICING`` table, rounded to 4 decimals.
    """
    rate_in, rate_out = _rate_for(model or config.MODEL)
    in_tok = meter.get("in_tok", 0)
    out_tok = meter.get("out_tok", 0)
    est = (in_tok / 1_000_000) * rate_in + (out_tok / 1_000_000) * rate_out
    return {
        "llm_calls": meter.get("llm_calls", 0),
        "in_tok": in_tok,
        "out_tok": out_tok,
        "est_cost_usd": round(est, 4),
    }


def chunk_text(text: str, size: int = 3000) -> list[str]:
    """Split text into ~``size``-char chunks, preferring paragraph breaks.

    Paragraphs (split on blank lines) are packed greedily, joined by a blank
    line, so that section coherence is preserved where possible. A single
    oversized paragraph is hard-split into ``size``-char pieces.

    Raises:
        ValueError: if ``size`` is not positive.
    """
    if size <= 0:
        # size == 0 would make range() below raise an opaque error, and a
        # negative size would silently drop every paragraph.
        raise ValueError(f"size must be positive, got {size}")
    paragraphs = re.split(r"\n\s*\n", text)
    chunks: list[str] = []
    current = ""
    for para in paragraphs:
        para = para.strip()
        if not para:
            continue
        if len(para) > size:
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(para[i : i + size] for i in range(0, len(para), size))
            continue
        # +2 accounts for the "\n\n" separator added when appending.
        if len(current) + len(para) + 2 > size:
            if current:
                chunks.append(current)
            current = para
        else:
            current = f"{current}\n\n{para}" if current else para
    if current:
        chunks.append(current)
    return chunks


def sanitize_filename(text: str, max_len: int = 80) -> str:
    """Turn an arbitrary topic string into a safe filename stem.

    Keeps word characters, whitespace and dashes, lowercases the text, collapses
    runs of whitespace/underscores/dashes into single dashes, truncates to
    ``max_len`` and trims edge dashes. Returns ``"review"`` if nothing is left.
    """
    slug = re.sub(r"[^\w\s-]", "", text).strip().lower()
    slug = re.sub(r"[\s_-]+", "-", slug)
    # Strip before choosing the fallback so inputs like "---" don't yield "".
    return slug[:max_len].strip("-") or "review"


def _format_authors(p: "Paper") -> str:
    """First three author names joined by ", " plus " et al." when there are more ("" if none)."""
    names = ", ".join(p.authors[:3])
    return names + (" et al." if len(p.authors) > 3 else "")


def _year_label(year: object) -> str:
    """Display form of a publication year; ``"n.d."`` when the year is missing/falsy."""
    return str(year) if year else "n.d."


def format_papers_for_synthesis(papers: list["Paper"]) -> str:
    """Render papers into a compact block for the synthesis prompt.

    Each paper gets a heading, an author line and bulleted claims/methods/results.
    The first 500 chars of the abstract are used only when none of those lists
    is populated.
    """
    blocks: list[str] = []
    for p in papers:
        parts = [
            f"### {p.title} ({_year_label(p.year)})",
            f"Authors: {_format_authors(p) or 'Unknown'}",
        ]
        for label, items in (("Claims", p.claims), ("Methods", p.methods), ("Results", p.results)):
            if items:
                parts.append(f"{label}:\n" + "\n".join(f" - {item}" for item in items))
        if not (p.claims or p.methods or p.results) and p.abstract:
            parts.append(f"Abstract: {p.abstract[:500]}")
        blocks.append("\n".join(parts))
    return "\n\n".join(blocks)


def paper_url(p: "Paper") -> str:
    """Best human-facing URL for a paper: arXiv abstract page, else the PDF URL, else ""."""
    if p.arxiv_id:
        return f"https://arxiv.org/abs/{p.arxiv_id}"
    if p.source == "arxiv":
        return f"https://arxiv.org/abs/{p.id}"
    return p.pdf_url or ""


def paper_meta(p: "Paper") -> dict:
    """Serializable metadata for the UI bibliography."""
    return {
        "title": p.title,
        "authors": _format_authors(p) or "Unknown",
        "year": p.year,
        "url": paper_url(p),
        "source": p.source,
        "citation_count": p.citation_count,
        "read_from": p.read_from,
    }


def references_markdown(papers: list["Paper"]) -> str:
    """A numbered, linked '## References' section for provenance ("" for no papers)."""
    if not papers:
        return ""
    lines = ["", "## References", ""]
    for i, p in enumerate(papers, 1):
        m = paper_meta(p)
        link = f"[{p.title}]({m['url']})" if m["url"] else p.title
        cites = f" · {p.citation_count} citations" if p.citation_count else ""
        abstract_note = " · abstract only" if p.read_from == "abstract" else ""
        lines.append(
            f"{i}. {link} — {m['authors']} ({_year_label(p.year)}){cites}{abstract_note}"
        )
    return "\n".join(lines)


# A bracketed group such as "[Smith, 2020]" or "[Smith, 2020; Lee et al., 2021]".
_CITATION_GROUP_RE = re.compile(r"\[([^\[\]]+)\]")
# One "Author part, YYYY" entry inside a group.
_CITATION_ENTRY_RE = re.compile(r"(.+?),\s*(\d{4})")
# Tokens inside a citation's author part that are never surnames.
_NON_SURNAME_TOKENS = frozenset({"et", "al", "and"})


def _surname(author: str) -> str:
    """Lowercased surname: the text before a comma ("Last, First"), else the last word."""
    name = author.split(",", 1)[0]
    words = name.split()
    return words[-1].lower().strip(".") if words else ""


def _find_citations(review: str) -> list[tuple[str, str]]:
    """Return (author_part, year) for every "Author, YYYY" entry inside [...] groups.

    Semicolon-separated groups yield one pair per entry.
    """
    found: list[tuple[str, str]] = []
    for group in _CITATION_GROUP_RE.findall(review):
        for entry in group.split(";"):
            m = _CITATION_ENTRY_RE.fullmatch(entry.strip())
            if m:
                found.append((m.group(1).strip(), m.group(2)))
    return found


def _citation_tokens(author_part: str) -> list[str]:
    """Lowercased candidate surname tokens of a citation's author part ("et al."/"and"/"&" dropped)."""
    tokens = (t.strip(".") for t in re.split(r"[\s&,]+", author_part.lower()))
    return [t for t in tokens if t and t not in _NON_SURNAME_TOKENS]


def validate_review(review: str, papers: list["Paper"]) -> dict:
    """Sanity-check a generated review.

    Confirms the review has at least two ``#``-``###`` section headers. Each
    ``[Author, Year]`` citation is checked against the paper set; ``;``-separated
    groups are checked entry by entry. A citation is flagged when none of its
    author tokens is a known surname, or when its year is not a known paper year.
    Each check is skipped if the paper set provides no surnames/years.

    Returns ``{"valid": bool, "issues": [str, ...]}`` with issues de-duplicated
    in first-seen order.
    """
    issues: list[str] = []
    headers = re.findall(r"^#{1,3}\s+\S", review, flags=re.MULTILINE)
    if len(headers) < 2:
        issues.append("Review has fewer than 2 section headers.")

    # Known author surnames and paper years, used to spot citations to papers
    # that are not in the set.
    known = {s for p in papers for s in map(_surname, p.authors) if s}
    known_years = {str(p.year) for p in papers if p.year}

    citations = _find_citations(review)
    if not citations:
        issues.append("No [Author, Year] citations found.")

    for author_part, year in citations:
        tokens = _citation_tokens(author_part)
        bad_author = bool(tokens) and bool(known) and not any(t in known for t in tokens)
        bad_year = bool(known_years) and year not in known_years
        if bad_author or bad_year:
            issues.append(f"Citation may not match any paper: [{author_part}, {year}]")

    unique_issues = list(dict.fromkeys(issues))
    return {"valid": not unique_issues, "issues": unique_issues}