Spaces:
Sleeping
Sleeping
| """ | |
| agent.py | |
| Mistral wrapper. The only file that calls the LLM. No analytical branching | |
| beyond what's in prompts.py — every interpretive decision is delegated to | |
| the model and reviewed by the researcher in the UI. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import json | |
| import re | |
| from typing import Any | |
| import prompts | |
# Model name sent with every chat request; override via the MISTRAL_MODEL
# environment variable (read once, when this module is imported).
MISTRAL_MODEL = os.getenv("MISTRAL_MODEL", "mistral-large-latest")
| # ---------------------------------------------------------- low-level | |
def _call_mistral(user_message: str, system_message: str | None = None,
                  max_tokens: int = 800, temperature: float = 0.0) -> str:
    """Send one chat-completion request to Mistral and return the reply text.

    Args:
        user_message: Content of the single ``user`` message.
        system_message: Optional ``system`` message; omitted when falsy.
        max_tokens: Upper bound on generated tokens, passed to the API.
        temperature: Sampling temperature, passed to the API (0.0 by default).

    Returns:
        The assistant reply with surrounding whitespace stripped.

    Raises:
        RuntimeError: if ``MISTRAL_API_KEY`` is unset or empty, or if the
            response carries no choices or no message content.
    """
    api_key = os.environ.get("MISTRAL_API_KEY")
    # An empty key is treated like a missing one: it would fail later inside
    # the SDK with a far less helpful error.
    if not api_key:
        raise RuntimeError(
            "MISTRAL_API_KEY is not set. On Hugging Face Spaces, add it under "
            "Settings -> Variables and secrets -> Secrets."
        )
    # Imported lazily: importing this module does not require mistralai.
    from mistralai import Mistral
    client = Mistral(api_key=api_key)

    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": user_message})

    response = client.chat.complete(
        model=MISTRAL_MODEL,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    if response is None or not response.choices:
        raise RuntimeError("Mistral returned a response with no choices.")

    content = response.choices[0].message.content
    if isinstance(content, list):
        # Content may arrive as a list of chunks; keep only chunks that
        # expose a ``text`` attribute (others, e.g. reasoning chunks, are skipped).
        content = "".join(getattr(chunk, "text", None) or "" for chunk in content)
    if content is None:
        raise RuntimeError("Mistral returned a message with no content.")
    return content.strip()
def _extract_json(text: str) -> dict:
    """Parse a model reply into a JSON object (``dict``).

    Leading ```` ``` ```` / ```` ```json ```` fences (any case) and a trailing
    ```` ``` ```` fence are stripped first. If the remainder is not valid JSON,
    the span from the first ``{`` to the last ``}`` is parsed instead. This
    recovers replies where the model wrapped the object in prose.

    Raises:
        ValueError: if no JSON can be parsed, or if the parsed value is not a
            JSON object (callers use ``dict`` methods on the result).
    """
    cleaned = text.strip()
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"\s*```$", "", cleaned)
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError as first_err:
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise ValueError(f"Mistral returned non-JSON: {text[:300]}") from first_err
        try:
            parsed = json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError as err:
            raise ValueError(f"Mistral returned non-JSON: {text[:300]}") from err
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Mistral returned JSON that is not an object: {text[:300]}"
        )
    return parsed
| # ---------------------------------------------------------- embedding agent ops | |
def label_cluster(cluster_id: int, n_papers: int,
                  top_papers: list[dict],
                  domain_context: str | None = None) -> dict:
    """Run ``prompts.prompt_label_cluster`` for one cluster and parse the reply.

    Sends the prompt with ``prompts.SYSTEM_PROMPT`` (400-token cap) and
    returns the model's JSON reply as parsed by ``_extract_json``.
    """
    prompt_kwargs = dict(
        cluster_id=cluster_id,
        n_papers=n_papers,
        top_papers=top_papers,
        domain_context=domain_context,
    )
    reply = _call_mistral(
        prompts.prompt_label_cluster(**prompt_kwargs),
        system_message=prompts.SYSTEM_PROMPT,
        max_tokens=400,
    )
    return _extract_json(reply)
def validate_label(proposed_label: str, cluster_id: int,
                   extra_papers: list[dict]) -> dict:
    """Run ``prompts.prompt_validate_label`` and return the parsed JSON reply.

    Uses ``prompts.SYSTEM_PROMPT`` and a 300-token cap.
    """
    prompt = prompts.prompt_validate_label(
        proposed_label=proposed_label,
        cluster_id=cluster_id,
        extra_papers=extra_papers,
    )
    return _extract_json(
        _call_mistral(prompt, system_message=prompts.SYSTEM_PROMPT, max_tokens=300)
    )
def recommend_parameters(n_papers: int, candidate_summaries: list[dict]) -> dict:
    """Run ``prompts.prompt_recommend_parameters`` and return the parsed JSON reply.

    Uses ``prompts.SYSTEM_PROMPT`` and a 400-token cap.
    """
    prompt = prompts.prompt_recommend_parameters(
        n_papers=n_papers,
        candidate_summaries=candidate_summaries,
    )
    raw_reply = _call_mistral(
        prompt,
        system_message=prompts.SYSTEM_PROMPT,
        max_tokens=400,
    )
    return _extract_json(raw_reply)
def freeform_question(question: str, cluster_summaries: list[dict]) -> str:
    """Run ``prompts.prompt_freeform_question`` and return the raw reply text.

    Unlike the other ops, the reply is not JSON-parsed. Uses
    ``prompts.SYSTEM_PROMPT`` and a 600-token cap.
    """
    prompt = prompts.prompt_freeform_question(
        question=question,
        cluster_summaries=cluster_summaries,
    )
    answer = _call_mistral(prompt, system_message=prompts.SYSTEM_PROMPT, max_tokens=600)
    return answer
| # ---------------------------------------------------------- TCCM agent ops | |
def tccm_marginal_review(paper: dict,
                         regex_signal_counts: dict,
                         fired_terms: dict) -> dict:
    """
    Read a MARGINAL paper and return {verdict, rationale, confidence}.
    The researcher sees this in the UI and can accept it, override it, or
    leave it MARGINAL.

    The prompt comes from ``prompts.prompt_tccm_marginal`` and is sent with
    ``prompts.SYSTEM_PROMPT`` (400-token cap). Only ``verdict`` is validated.
    It is normalised (whitespace stripped, upper-cased) and forced to
    "MARGINAL" unless it is one of INCLUDE / EXCLUDE / MARGINAL. All other
    keys are passed through unchanged.

    Raises:
        ValueError: if the reply is not JSON or is JSON but not an object.
    """
    user_prompt = prompts.prompt_tccm_marginal(
        paper=paper,
        regex_signal_counts=regex_signal_counts,
        fired_terms=fired_terms,
    )
    text = _call_mistral(user_prompt, prompts.SYSTEM_PROMPT, max_tokens=400)
    parsed = _extract_json(text)
    # A list/scalar reply would otherwise crash below on ``.get``.
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Mistral returned JSON that is not an object: {text[:300]}"
        )
    # Light defensive validation: the verdict must be one of the three.
    # Normalise case and whitespace first, so "include" is not silently
    # downgraded to MARGINAL.
    verdict = parsed.get("verdict")
    if isinstance(verdict, str):
        verdict = verdict.strip().upper()
    parsed["verdict"] = (
        verdict if verdict in ("INCLUDE", "EXCLUDE", "MARGINAL") else "MARGINAL"
    )
    return parsed