""" agent.py Mistral wrapper. The only file that calls the LLM. No analytical branching beyond what's in prompts.py — every interpretive decision is delegated to the model and reviewed by the researcher in the UI. """ from __future__ import annotations import os import json import re from typing import Any import prompts MISTRAL_MODEL = os.environ.get("MISTRAL_MODEL", "mistral-large-latest") # ---------------------------------------------------------- low-level def _call_mistral(user_message: str, system_message: str | None = None, max_tokens: int = 800, temperature: float = 0.0) -> str: if "MISTRAL_API_KEY" not in os.environ: raise RuntimeError( "MISTRAL_API_KEY is not set. On Hugging Face Spaces, add it under " "Settings -> Variables and secrets -> Secrets." ) from mistralai import Mistral client = Mistral(api_key=os.environ["MISTRAL_API_KEY"]) messages = [] if system_message: messages.append({"role": "system", "content": system_message}) messages.append({"role": "user", "content": user_message}) response = client.chat.complete( model=MISTRAL_MODEL, messages=messages, max_tokens=max_tokens, temperature=temperature, ) return response.choices[0].message.content.strip() def _extract_json(text: str) -> dict: cleaned = text.strip() cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned) cleaned = re.sub(r"\s*```$", "", cleaned) try: return json.loads(cleaned) except json.JSONDecodeError as e: raise ValueError(f"Mistral returned non-JSON: {text[:300]}") from e # ---------------------------------------------------------- embedding agent ops def label_cluster(cluster_id: int, n_papers: int, top_papers: list[dict], domain_context: str | None = None) -> dict: user_prompt = prompts.prompt_label_cluster( cluster_id=cluster_id, n_papers=n_papers, top_papers=top_papers, domain_context=domain_context, ) text = _call_mistral(user_prompt, prompts.SYSTEM_PROMPT, max_tokens=400) return _extract_json(text) def validate_label(proposed_label: str, cluster_id: int, extra_papers: list[dict]) -> dict: user_prompt = prompts.prompt_validate_label( proposed_label=proposed_label, cluster_id=cluster_id, extra_papers=extra_papers, ) text = _call_mistral(user_prompt, prompts.SYSTEM_PROMPT, max_tokens=300) return _extract_json(text) def recommend_parameters(n_papers: int, candidate_summaries: list[dict]) -> dict: user_prompt = prompts.prompt_recommend_parameters( n_papers=n_papers, candidate_summaries=candidate_summaries, ) text = _call_mistral(user_prompt, prompts.SYSTEM_PROMPT, max_tokens=400) return _extract_json(text) def freeform_question(question: str, cluster_summaries: list[dict]) -> str: user_prompt = prompts.prompt_freeform_question( question=question, cluster_summaries=cluster_summaries, ) return _call_mistral(user_prompt, prompts.SYSTEM_PROMPT, max_tokens=600) # ---------------------------------------------------------- TCCM agent ops def tccm_marginal_review(paper: dict, regex_signal_counts: dict, fired_terms: dict) -> dict: """ Read a MARGINAL paper and return {verdict, rationale, confidence}. The researcher sees this in the UI and can accept it, override it, or leave it MARGINAL. """ user_prompt = prompts.prompt_tccm_marginal( paper=paper, regex_signal_counts=regex_signal_counts, fired_terms=fired_terms, ) text = _call_mistral(user_prompt, prompts.SYSTEM_PROMPT, max_tokens=400) parsed = _extract_json(text) # Light defensive validation — verdict must be one of the three if parsed.get("verdict") not in ("INCLUDE", "EXCLUDE", "MARGINAL"): parsed["verdict"] = "MARGINAL" return parsed