"""
classifier.py — two LLM calls for the two-stage flow.

    extract_schema(study_context, preview) → ExtractResult (schema, no label)
    pick_label(schema, study_context)      → LabelResult (label, reasoning)

The orchestrator (statlens_run.py) drives both. Each call is independent —
the user can edit the schema between them, and the label decision uses the
edited schema.
"""
from __future__ import annotations

import json
import re
from dataclasses import dataclass, field

import httpx

from .prompts import (
    LABELS,
    build_extract_messages,
    build_label_messages,
)
from .schema_spec import coerce_and_validate, default_schema


@dataclass
class ExtractResult:
    """Output of stage 1: just the schema, no label."""

    schema: dict  # coerce_and_validate() output, or default_schema() if no JSON was found
    raw: str  # raw model reply text
    schema_warnings: list[str] = field(default_factory=list)  # validation / parse warnings


@dataclass
class LabelResult:
    """Output of stage 3: a label picked from a (confirmed) schema."""

    label: str  # "none_of_these" when no JSON object could be parsed from the reply
    reasoning: str  # model's explanation; "" when absent or null
    raw: str  # raw model reply text
    valid: bool  # whether label is in LABELS (after case-insensitive normalization)


# Shared decoder; raw_decode lets us parse an object starting at any offset.
_DECODER = json.JSONDecoder()


def _extract_json_object(text: str) -> dict | None:
    """Return the largest JSON object embedded in a blob of model output.

    The whole (stripped) text is tried first. Otherwise every "{" is tried as
    the start of an object with ``JSONDecoder.raw_decode``. That parser is
    string-aware, so braces inside quoted values (e.g. a reasoning string that
    mentions "{group}") do not break extraction the way naive brace counting
    does. Surrounding prose or Markdown code fences are ignored.

    Args:
        text: raw model reply.

    Returns:
        The largest decodable JSON object as a dict, or None when ``text`` is
        empty or contains no decodable object. Top-level JSON values that are
        not objects (lists, strings, numbers) are never returned, so callers
        can safely use dict methods on the result.
    """
    if not text:
        return None

    try:
        whole = json.loads(text.strip())
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        pass
    else:
        if isinstance(whole, dict):
            return whole

    best: dict | None = None
    best_len = -1
    start = text.find("{")
    while start != -1:
        try:
            obj, end = _DECODER.raw_decode(text, start)
        except ValueError:
            start = text.find("{", start + 1)
            continue
        # Decoding at a "{" always yields a dict (a JSON object).
        if end - start > best_len:
            best, best_len = obj, end - start
        # Any object nested inside this one is strictly smaller — skip past it.
        start = text.find("{", end)
    return best


def _post_chat(endpoint: str, messages: list[dict], model: str, api_key: str,
               timeout: float, max_tokens: int) -> str:
    """POST an OpenAI-style chat-completion request and return the reply text.

    Args:
        endpoint: base URL; "/chat/completions" is appended.
        messages: chat messages forwarded verbatim in the request body.
        model: model name forwarded in the request body.
        api_key: sent as a Bearer token.
        timeout: httpx client timeout in seconds.
        max_tokens: generation cap forwarded in the request body.

    Returns:
        ``choices[0].message.content`` from the response. Returns "" when that
        field is null or not a string, so downstream parsing always gets a str.

    Raises:
        httpx.HTTPError: on transport failure, timeout, or a non-2xx status.
        KeyError / IndexError: if the response JSON lacks choices[0].message.
    """
    url = endpoint.rstrip("/") + "/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}",
               "Content-Type": "application/json"}
    body = {"model": model, "messages": messages,
            "temperature": 0.0, "max_tokens": max_tokens}
    with httpx.Client(timeout=timeout) as client:
        r = client.post(url, json=body, headers=headers)
        r.raise_for_status()
        data = r.json()
    content = data["choices"][0]["message"]["content"]
    # OpenAI-compatible servers may send `"content": null` (e.g. refusals or
    # tool calls); normalize so callers never receive None.
    return content if isinstance(content, str) else ""


def _text_field(obj: dict, key: str) -> str:
    """Return ``obj[key]`` as a stripped string; missing or null values become ""."""
    value = obj.get(key)
    return "" if value is None else str(value).strip()


# ─────────────────────────────────────────────────────────────────────
# Stage 1 — extract schema only
# ─────────────────────────────────────────────────────────────────────
def extract_schema(
    study_context: str,
    preview: str,
    endpoint: str,
    model: str = "statlens",
    api_key: str = "dummy",
    timeout: float = 120.0,
) -> ExtractResult:
    """Call the LLM and parse the structured SchemaSummary. No label here.

    Args:
        study_context: study description, passed to build_extract_messages.
        preview: raw data preview, passed to build_extract_messages.
        endpoint: base URL of an OpenAI-compatible chat API.
        model, api_key, timeout: forwarded to the HTTP request.

    Returns:
        ExtractResult. If no JSON object can be found in the reply, the schema
        is ``default_schema()`` and a warning explains why. Otherwise the parsed
        object (unwrapped from a top-level "schema" key if present) goes through
        ``coerce_and_validate`` and its warnings are returned.

    Raises:
        httpx.HTTPError: if the request fails or returns an error status.
    """
    messages = build_extract_messages(study_context, preview)
    raw = _post_chat(endpoint, messages, model, api_key, timeout, max_tokens=1200)
    obj = _extract_json_object(raw)
    if obj is None:
        return ExtractResult(
            schema=default_schema(),
            raw=raw,
            schema_warnings=["LLM output was not valid JSON; using defaults"],
        )
    # Some models nest the schema under a "schema" key — accept either.
    if isinstance(obj.get("schema"), dict):
        obj = obj["schema"]
    schema, warns = coerce_and_validate(obj)
    return ExtractResult(schema=schema, raw=raw, schema_warnings=warns)


# ─────────────────────────────────────────────────────────────────────
# Stage 3 — pick label given a (confirmed) schema
# ─────────────────────────────────────────────────────────────────────
def pick_label(
    schema: dict,
    study_context: str,
    endpoint: str,
    model: str = "statlens",
    api_key: str = "dummy",
    timeout: float = 60.0,
) -> LabelResult:
    """Call the LLM with the schema (rendered by build_label_messages) and parse a label.

    Args:
        schema: the (possibly user-edited) schema from stage 1.
        study_context: study description, passed to build_label_messages.
        endpoint: base URL of an OpenAI-compatible chat API.
        model, api_key, timeout: forwarded to the HTTP request.

    Returns:
        LabelResult. When no JSON object is found, label is "none_of_these"
        and valid is False. Otherwise the "label" field is matched against
        LABELS, first exactly and then case-insensitively (normalized to the
        canonical spelling). An unmatched label is returned as-is with
        valid=False.

    Raises:
        httpx.HTTPError: if the request fails or returns an error status.
    """
    messages = build_label_messages(study_context, schema)
    raw = _post_chat(endpoint, messages, model, api_key, timeout, max_tokens=400)
    obj = _extract_json_object(raw)
    if obj is None:
        return LabelResult(
            label="none_of_these",
            reasoning="(failed to parse JSON from model)",
            raw=raw,
            valid=False,
        )

    label = _text_field(obj, "label")
    reasoning = _text_field(obj, "reasoning")
    valid = label in LABELS
    if not valid:
        # Normalize case-insensitive matches before giving up.
        folded = label.lower()
        canonical = next((cand for cand in LABELS if cand.lower() == folded), None)
        if canonical is not None:
            label, valid = canonical, True
    return LabelResult(label=label, reasoning=reasoning, raw=raw, valid=valid)


if __name__ == "__main__":
    # Uses relative imports, so run as a module: python -m <package>.classifier ...
    import argparse
    from pathlib import Path

    from .raw_preview import build_raw_preview

    ap = argparse.ArgumentParser()
    ap.add_argument("--context", required=True)
    ap.add_argument("--tsv", required=True)
    ap.add_argument("--endpoint", required=True)
    ap.add_argument("--model", default="statlens")
    ap.add_argument("--api-key", default="dummy")
    args = ap.parse_args()

    ctx = Path(args.context).read_text(encoding="utf-8")
    preview = build_raw_preview(Path(args.tsv))

    print("=== Stage 1: extract schema ===")
    er = extract_schema(ctx, preview, args.endpoint, args.model, args.api_key)
    print(f"  warnings : {er.schema_warnings}")
    print(json.dumps(er.schema, indent=2))

    print("\n=== Stage 3: pick label (using LLM-extracted schema verbatim) ===")
    lr = pick_label(er.schema, ctx, args.endpoint, args.model, args.api_key)
    print(f"  label    : {lr.label}  (valid={lr.valid})")
    print(f"  reasoning: {lr.reasoning}")