Spaces:
Running
Running
| """Shared helpers: LLM completion, chunking, formatting, validation.""" | |
| from __future__ import annotations | |
| import re | |
| import threading | |
| import time | |
| from typing import TYPE_CHECKING | |
| from . import config | |
| _meter_lock = threading.Lock() | |
| if TYPE_CHECKING: # avoid importing heavy modules at runtime just for typing | |
| from .tools import Paper | |
def _record_usage(meter: dict | None, in_tok: int, out_tok: int) -> None:
    """Add one LLM call and its token counts to ``meter`` in place.

    Does nothing when ``meter`` is ``None``. Counters missing from the dict
    start at 0, and a falsy token count (``None`` or 0) is counted as 0.
    """
    if meter is not None:
        prompt_tokens = in_tok or 0
        completion_tokens = out_tok or 0
        # Called from parallel reader threads, so the read-modify-write on the
        # shared dict happens under the module-level lock.
        with _meter_lock:
            meter["llm_calls"] = meter.get("llm_calls", 0) + 1
            meter["in_tok"] = meter.get("in_tok", 0) + prompt_tokens
            meter["out_tok"] = meter.get("out_tok", 0) + completion_tokens
def complete(
    prompt: str,
    system: str | None = None,
    max_tokens: int = 4096,
    meter: dict | None = None,
    model: str | None = None,
) -> str:
    """Run one prompt through the configured LLM provider and return its text.

    ``system`` is an optional system prompt and ``max_tokens`` is forwarded to
    the provider call. ``model`` replaces ``config.MODEL`` for this call only.
    When ``meter`` is given, the call's token usage is added to it.

    Any exception triggers one retry after a 1.5 second pause; if the second
    attempt also fails, that exception is re-raised to the caller.
    """
    client = config.get_llm_client()
    chosen_model = model or config.MODEL
    failure: Exception | None = None
    for attempt in range(2):
        try:
            if config.PROVIDER != "anthropic":
                # openai / gemini (OpenAI-compatible)
                chat = [{"role": "user", "content": prompt}]
                if system:
                    chat.insert(0, {"role": "system", "content": system})
                resp = client.chat.completions.create(
                    model=chosen_model,
                    max_tokens=max_tokens,
                    messages=chat,
                )
                usage = getattr(resp, "usage", None)
                _record_usage(
                    meter,
                    getattr(usage, "prompt_tokens", 0),
                    getattr(usage, "completion_tokens", 0),
                )
                return (resp.choices[0].message.content or "").strip()

            # Anthropic takes the system prompt as a top-level argument and
            # only when one is actually supplied.
            request = {
                "model": chosen_model,
                "max_tokens": max_tokens,
                "messages": [{"role": "user", "content": prompt}],
            }
            if system:
                request["system"] = system
            resp = client.messages.create(**request)
            usage = getattr(resp, "usage", None)
            _record_usage(
                meter,
                getattr(usage, "input_tokens", 0),
                getattr(usage, "output_tokens", 0),
            )
            text_parts = [block.text for block in resp.content if block.type == "text"]
            return "".join(text_parts).strip()
        except Exception as exc:  # noqa: BLE001 - retry-once then surface
            failure = exc
            if attempt > 0:
                raise
            time.sleep(1.5)
    # Unreachable, but keeps type checkers happy.
    raise RuntimeError(f"LLM completion failed: {failure}")
# Approximate USD price per 1M tokens as (input, output). Used only for a
# rough estimate.
_PRICING = {
    "gemini-2.5-flash": (0.30, 2.50),
    "gemini-2.5-pro": (1.25, 10.0),
    "gpt-4o": (2.50, 10.0),
    "claude-sonnet": (3.0, 15.0),
    "claude-opus": (15.0, 75.0),
}


def cost_report(meter: dict, model: str | None = None) -> dict:
    """Summarise an accumulated meter as call/token counts plus a rough cost.

    The pricing row is the first ``_PRICING`` key that occurs as a substring
    of the lowercased model name (``model``, or ``config.MODEL`` when
    ``model`` is falsy). Models matching no key are priced at (0.50, 1.50).

    Returns a dict with keys ``llm_calls``, ``in_tok``, ``out_tok`` and
    ``est_cost_usd`` (rounded to 4 decimal places). Counters missing from
    ``meter`` are reported as 0.
    """
    model_name = (model or config.MODEL).lower()

    in_rate, out_rate = 0.50, 1.50  # fallback when no pricing key matches
    for key, prices in _PRICING.items():
        if key in model_name:
            in_rate, out_rate = prices
            break

    tokens_in = meter.get("in_tok", 0)
    tokens_out = meter.get("out_tok", 0)
    estimate = (tokens_in / 1_000_000) * in_rate + (tokens_out / 1_000_000) * out_rate

    report = {
        "llm_calls": meter.get("llm_calls", 0),
        "in_tok": tokens_in,
        "out_tok": tokens_out,
        "est_cost_usd": round(estimate, 4),
    }
    return report
def chunk_text(text: str, size: int = 3000) -> list[str]:
    """Split text into chunks of at most ``size`` characters.

    Paragraphs (separated by blank lines) are stripped and packed greedily,
    rejoined with a blank line, so section coherence is preserved where
    possible. A single paragraph longer than ``size`` is hard-split into
    ``size``-character pieces, each emitted as its own chunk.

    Args:
        text: The text to split.
        size: Maximum chunk length in characters; must be at least 1.

    Returns:
        The chunks in document order. Empty or whitespace-only text gives
        an empty list.

    Raises:
        ValueError: If ``size`` is less than 1.
    """
    # A negative size would make the hard-split range empty and silently drop
    # every paragraph; zero would fail inside range() with an unrelated message.
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size!r}")

    chunks: list[str] = []
    current = ""
    for para in re.split(r"\n\s*\n", text):
        para = para.strip()
        if not para:
            continue
        if len(para) > size:
            # Flush what has been packed so far to keep document order, then
            # emit the oversized paragraph in fixed-size slices.
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(para[i : i + size] for i in range(0, len(para), size))
            continue
        # The "+ 2" accounts for the blank-line separator added when joining.
        if current and len(current) + len(para) + 2 <= size:
            current = f"{current}\n\n{para}"
        else:
            if current:
                chunks.append(current)
            current = para
    if current:
        chunks.append(current)
    return chunks
def sanitize_filename(text: str, max_len: int = 80) -> str:
    """Turn an arbitrary topic string into a safe, non-empty filename stem.

    Characters other than word characters, whitespace and hyphens are removed,
    the result is lowercased, and every run of whitespace, underscores and
    hyphens collapses to a single ``-``. The slug is cut to ``max_len``
    characters and stripped of leading/trailing hyphens.

    Returns ``"review"`` when nothing usable remains.
    """
    slug = re.sub(r"[^\w\s-]", "", text).strip().lower()
    slug = re.sub(r"[\s_-]+", "-", slug)
    # Strip hyphens BEFORE applying the fallback: input that reduces to only
    # separators (e.g. "---" or "_") would otherwise yield an empty stem.
    return slug[:max_len].strip("-") or "review"
def format_papers_for_synthesis(papers: list["Paper"]) -> str:
    """Render papers as compact text blocks for the synthesis prompt.

    Each paper gets a ``### title (year)`` heading and an author line (first
    three authors, ``et al.`` beyond that, ``Unknown`` when there are none),
    followed by bulleted Claims / Methods / Results sections for whichever of
    those are non-empty. When all three are empty, the first 500 characters
    of the abstract are used instead. Blocks are separated by a blank line.
    """
    rendered: list[str] = []
    for paper in papers:
        names = ", ".join(paper.authors[:3])
        if len(paper.authors) > 3:
            names += " et al."
        lines = [
            f"### {paper.title} ({paper.year})",
            f"Authors: {names or 'Unknown'}",
        ]

        sections = (
            ("Claims", paper.claims),
            ("Methods", paper.methods),
            ("Results", paper.results),
        )
        has_extraction = False
        for label, items in sections:
            if items:
                has_extraction = True
                lines.append(f"{label}:\n" + "\n".join(f" - {item}" for item in items))

        # Fall back to the abstract only when nothing was extracted.
        if not has_extraction and paper.abstract:
            lines.append(f"Abstract: {paper.abstract[:500]}")

        rendered.append("\n".join(lines))
    return "\n\n".join(rendered)
def paper_url(p: "Paper") -> str:
    """Return the best human-facing URL for a paper.

    An arXiv abstract page is preferred: built from ``arxiv_id`` when set,
    otherwise from ``id`` when the paper's source is ``"arxiv"``. Any other
    paper falls back to its ``pdf_url``, or an empty string if it has none.
    """
    if p.arxiv_id:
        identifier = p.arxiv_id
    elif p.source == "arxiv":
        identifier = p.id
    else:
        return p.pdf_url or ""
    return f"https://arxiv.org/abs/{identifier}"
def paper_meta(p: "Paper") -> dict:
    """Build the serializable metadata dict used by the UI bibliography.

    ``authors`` is the first three author names joined by commas, with
    ``et al.`` appended when there are more, or ``Unknown`` when the paper
    lists none. ``url`` comes from ``paper_url``; the remaining fields are
    copied straight from the paper.
    """
    byline = ", ".join(p.authors[:3])
    if len(p.authors) > 3:
        byline += " et al."

    meta = {
        "title": p.title,
        "authors": byline or "Unknown",
        "year": p.year,
        "url": paper_url(p),
        "source": p.source,
        "citation_count": p.citation_count,
        "read_from": p.read_from,
    }
    return meta
def references_markdown(papers: list["Paper"]) -> str:
    """Build a numbered, linked ``## References`` section for provenance.

    Returns an empty string when there are no papers. Otherwise the output
    starts with a blank line and the heading, followed by one numbered entry
    per paper: the title (as a markdown link when a URL is available), the
    authors, the year, and optional citation-count and "abstract only" notes.
    """
    if not papers:
        return ""

    out = ["", "## References", ""]
    for number, paper in enumerate(papers, 1):
        meta = paper_meta(paper)

        if meta["url"]:
            title = f"[{paper.title}]({meta['url']})"
        else:
            title = paper.title

        suffix = ""
        if paper.citation_count:
            suffix += f" · {paper.citation_count} citations"
        # Flag entries that were read from the abstract rather than full text.
        if paper.read_from == "abstract":
            suffix += " · abstract only"

        out.append(f"{number}. {title} — {meta['authors']} ({paper.year}){suffix}")
    return "\n".join(out)
def validate_review(review: str, papers: list["Paper"]) -> dict:
    """Sanity-check a generated review.

    Checks that the review has at least two markdown section headers (levels
    1-3), that it contains ``[Author, Year]`` citations, and that each
    citation's first author token is a known surname and its year is the year
    of some paper in ``papers``. A surname is the last whitespace-separated
    token of an author name, compared case-insensitively. The author check is
    skipped when no surnames are known, and the year check when no years are.

    Returns ``{"valid": bool, "issues": [str, ...]}`` with duplicate issues
    removed and first-seen order kept; ``valid`` is true when no issues remain.
    """
    issues: list[str] = []

    headers = re.findall(r"^#{1,3}\s+\S", review, flags=re.MULTILINE)
    if len(headers) < 2:
        issues.append("Review has fewer than 2 section headers.")

    # Known author surnames (lowercased) and the set of paper years, used to
    # spot citations to papers that are not in the set.
    known: set[str] = set()
    known_years: set[str] = set()
    for p in papers:
        if p.year:
            known_years.add(str(p.year))
        for author in p.authors:
            name_tokens = author.split()
            if name_tokens:
                known.add(name_tokens[-1].lower())

    citations = re.findall(r"\[([^\]]+?),\s*(\d{4})\]", review)
    if not citations:
        issues.append("No [Author, Year] citations found.")

    for author_part, year in citations:
        token = re.split(r"\s+|&|,", author_part.strip())[0].lower()
        # A citation that starts with "et" (as in "et al.") names no author,
        # so skip the author check. Compare the whole token: deleting the
        # substring "et" would mangle real surnames such as "Peters".
        if token == "et":
            token = ""
        bad_author = bool(token) and bool(known) and token not in known
        bad_year = bool(known_years) and year not in known_years
        if bad_author or bad_year:
            issues.append(f"Citation may not match any paper: [{author_part}, {year}]")

    # De-duplicate while preserving order.
    unique_issues = list(dict.fromkeys(issues))
    return {"valid": not unique_issues, "issues": unique_issues}