# TOPICMODELLING / agent.py
# Milind Kamat
# Add TCCM screening pipeline; migrate to Supabase; integrated UI
# 8d75855
"""
agent.py
Mistral wrapper. The only file that calls the LLM. No analytical branching
beyond what's in prompts.py — every interpretive decision is delegated to
the model and reviewed by the researcher in the UI.
"""
from __future__ import annotations
import os
import json
import re
from typing import Any
import prompts
# Model name passed to every chat request in _call_mistral. Read once at import
# time; set the MISTRAL_MODEL environment variable to override the default.
MISTRAL_MODEL = os.environ.get("MISTRAL_MODEL", "mistral-large-latest")
# ---------------------------------------------------------- low-level
def _call_mistral(user_message: str, system_message: str | None = None,
max_tokens: int = 800, temperature: float = 0.0) -> str:
if "MISTRAL_API_KEY" not in os.environ:
raise RuntimeError(
"MISTRAL_API_KEY is not set. On Hugging Face Spaces, add it under "
"Settings -> Variables and secrets -> Secrets."
)
from mistralai import Mistral
client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
messages.append({"role": "user", "content": user_message})
response = client.chat.complete(
model=MISTRAL_MODEL,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
return response.choices[0].message.content.strip()
def _extract_json(text: str) -> dict:
cleaned = text.strip()
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
cleaned = re.sub(r"\s*```$", "", cleaned)
try:
return json.loads(cleaned)
except json.JSONDecodeError as e:
raise ValueError(f"Mistral returned non-JSON: {text[:300]}") from e
# ---------------------------------------------------------- embedding agent ops
def label_cluster(cluster_id: int, n_papers: int,
                  top_papers: list[dict],
                  domain_context: str | None = None) -> dict:
    """Send the cluster-labelling prompt to the model; return its JSON reply.

    The prompt is built by ``prompts.prompt_label_cluster`` from the cluster
    id, its size, its top papers and optional domain context. The reply is
    capped at 400 tokens and parsed with ``_extract_json``.
    """
    prompt_kwargs = dict(
        cluster_id=cluster_id,
        n_papers=n_papers,
        top_papers=top_papers,
        domain_context=domain_context,
    )
    reply = _call_mistral(
        prompts.prompt_label_cluster(**prompt_kwargs),
        prompts.SYSTEM_PROMPT,
        max_tokens=400,
    )
    return _extract_json(reply)
def validate_label(proposed_label: str, cluster_id: int,
                   extra_papers: list[dict]) -> dict:
    """Send the label-validation prompt to the model; return its JSON reply.

    The prompt is built by ``prompts.prompt_validate_label`` from the proposed
    label, the cluster id and the extra papers. The reply is capped at
    300 tokens and parsed with ``_extract_json``.
    """
    reply = _call_mistral(
        prompts.prompt_validate_label(
            proposed_label=proposed_label,
            cluster_id=cluster_id,
            extra_papers=extra_papers,
        ),
        prompts.SYSTEM_PROMPT,
        max_tokens=300,
    )
    return _extract_json(reply)
def recommend_parameters(n_papers: int, candidate_summaries: list[dict]) -> dict:
    """Send the parameter-recommendation prompt; return the model's JSON reply.

    The prompt is built by ``prompts.prompt_recommend_parameters`` from the
    corpus size and the candidate summaries. The reply is capped at 400 tokens
    and parsed with ``_extract_json``.
    """
    prompt_text = prompts.prompt_recommend_parameters(
        n_papers=n_papers,
        candidate_summaries=candidate_summaries,
    )
    raw_reply = _call_mistral(prompt_text, prompts.SYSTEM_PROMPT, max_tokens=400)
    parsed_reply = _extract_json(raw_reply)
    return parsed_reply
def freeform_question(question: str, cluster_summaries: list[dict]) -> str:
    """Send a free-text question to the model and return its raw reply text.

    The prompt is built by ``prompts.prompt_freeform_question`` from the
    question and the cluster summaries. Unlike the other helpers here, the
    reply is returned as text rather than parsed as JSON. It is capped at
    600 tokens.
    """
    return _call_mistral(
        prompts.prompt_freeform_question(
            question=question,
            cluster_summaries=cluster_summaries,
        ),
        prompts.SYSTEM_PROMPT,
        max_tokens=600,
    )
# ---------------------------------------------------------- TCCM agent ops
def tccm_marginal_review(paper: dict,
                         regex_signal_counts: dict,
                         fired_terms: dict) -> dict:
    """
    Read a MARGINAL paper and return {verdict, rationale, confidence}.

    The researcher sees this in the UI and can accept it, override it, or
    leave it MARGINAL.

    Only ``verdict`` is validated here; any other keys (e.g. rationale,
    confidence) are passed through exactly as the model produced them.

    Args:
        paper: The paper record, passed to ``prompts.prompt_tccm_marginal``.
        regex_signal_counts: Regex signal counts, passed to the prompt builder.
        fired_terms: Matched terms, passed to the prompt builder.

    Returns:
        The parsed reply dict. Its ``verdict`` is normalised to exactly one
        of "INCLUDE", "EXCLUDE" or "MARGINAL".

    Raises:
        ValueError: If the reply is not JSON or is not a JSON object.
    """
    user_prompt = prompts.prompt_tccm_marginal(
        paper=paper,
        regex_signal_counts=regex_signal_counts,
        fired_terms=fired_terms,
    )
    text = _call_mistral(user_prompt, prompts.SYSTEM_PROMPT, max_tokens=400)
    parsed = _extract_json(text)
    # The code below indexes the reply as a dict; fail clearly otherwise.
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Mistral returned JSON that is not an object: {text[:300]}"
        )
    # Normalise case and whitespace first, so "include" or " Exclude " is not
    # silently downgraded. A missing, non-string or unknown verdict still
    # falls back to MARGINAL.
    raw_verdict = parsed.get("verdict")
    verdict = raw_verdict.strip().upper() if isinstance(raw_verdict, str) else None
    if verdict not in ("INCLUDE", "EXCLUDE", "MARGINAL"):
        verdict = "MARGINAL"
    parsed["verdict"] = verdict
    return parsed