# statlens/src/statlens/classifier.py
"""
classifier.py β€” two LLM calls for the two-stage flow.
extract_schema(study_context, preview) β†’ SchemaSummary (no label)
pick_label (schema, study_context) β†’ (label, reasoning)
The orchestrator (statlens_run.py) drives both. Each call is independent β€”
the user can edit the schema between them, and the label decision uses the
edited schema.
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass, field
import httpx
from .prompts import (
LABELS,
build_extract_messages,
build_label_messages,
)
from .schema_spec import coerce_and_validate, default_schema
@dataclass
class ExtractResult:
    """Output of stage 1: just the schema, no label.

    Produced by extract_schema(); per the module docstring, the user may
    edit ``schema`` before it is handed to pick_label().
    """
    schema: dict  # coerced/validated schema dict (defaults when the LLM reply had no parseable JSON)
    raw: str  # verbatim LLM response text, kept for debugging
    schema_warnings: list[str] = field(default_factory=list)  # notes from coerce_and_validate, or the JSON-parse-failure notice
@dataclass
class LabelResult:
    """Output of stage 3: a label picked from a (confirmed) schema."""
    label: str  # chosen label; "none_of_these" sentinel when the reply had no parseable JSON
    reasoning: str  # model-supplied rationale ("" when absent)
    raw: str  # verbatim LLM response text, kept for debugging
    valid: bool  # whether label is in LABELS
def _extract_json_object(text: str) -> dict | None:
"""Find the largest balanced JSON object in a blob of text."""
try:
return json.loads(text.strip())
except Exception:
pass
starts = [i for i, c in enumerate(text) if c == "{"]
for s in starts:
depth = 0
for i in range(s, len(text)):
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
if depth == 0:
cand = text[s:i+1]
try:
return json.loads(cand)
except Exception:
break
matches = re.findall(r"\{[^{}]*\}", text, re.DOTALL)
for m in sorted(matches, key=len, reverse=True):
try:
return json.loads(m)
except Exception:
continue
return None
def _post_chat(endpoint: str, messages: list[dict], model: str,
               api_key: str, timeout: float, max_tokens: int) -> str:
    """POST an OpenAI-style chat completion and return the reply text.

    Sends ``messages`` to ``{endpoint}/chat/completions`` with temperature
    pinned to 0.0 (deterministic decoding) and returns the first choice's
    message content. Raises httpx.HTTPStatusError on non-2xx responses.
    """
    target = endpoint.rstrip("/") + "/chat/completions"
    payload = {
        "model": model,
        "messages": messages,
        "temperature": 0.0,
        "max_tokens": max_tokens,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    with httpx.Client(timeout=timeout) as session:
        response = session.post(target, json=payload, headers=auth_headers)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
# ─────────────────────────────────────────────────────────────────────
# Stage 1 β€” extract schema only
# ─────────────────────────────────────────────────────────────────────
def extract_schema(
    study_context: str,
    preview: str,
    endpoint: str,
    model: str = "statlens",
    api_key: str = "dummy",
    timeout: float = 120.0,
) -> ExtractResult:
    """Stage 1: ask the LLM for a structured SchemaSummary (no label).

    When the reply contains no parseable JSON, falls back to
    default_schema() and records a warning. If the model wrapped its
    answer under a top-level "schema" key, that inner object is used.
    The result passes through coerce_and_validate() before returning.
    """
    raw = _post_chat(
        endpoint,
        build_extract_messages(study_context, preview),
        model, api_key, timeout, max_tokens=1200,
    )
    parsed = _extract_json_object(raw)
    if parsed is None:
        return ExtractResult(
            schema=default_schema(),
            raw=raw,
            schema_warnings=["LLM output was not valid JSON; using defaults"],
        )
    # Some models nest the schema under a "schema" key — accept either.
    nested = parsed.get("schema")
    if isinstance(nested, dict):
        parsed = nested
    schema, warnings = coerce_and_validate(parsed)
    return ExtractResult(schema=schema, raw=raw, schema_warnings=warnings)
# ─────────────────────────────────────────────────────────────────────
# Stage 3 β€” pick label given a (confirmed) schema
# ─────────────────────────────────────────────────────────────────────
def pick_label(
    schema: dict,
    study_context: str,
    endpoint: str,
    model: str = "statlens",
    api_key: str = "dummy",
    timeout: float = 60.0,
) -> LabelResult:
    """Stage 3: ask the LLM to choose a label for a (confirmed) schema.

    The schema is rendered as natural-language bullets by
    build_label_messages(). An unparseable reply yields the sentinel
    "none_of_these" with valid=False. A label that matches a LABELS
    entry only case-insensitively is normalized to the canonical
    spelling and treated as valid.
    """
    raw = _post_chat(endpoint, build_label_messages(study_context, schema),
                     model, api_key, timeout, max_tokens=400)
    parsed = _extract_json_object(raw)
    if parsed is None:
        return LabelResult(
            label="none_of_these",
            reasoning="(failed to parse JSON from model)",
            raw=raw,
            valid=False,
        )
    label = str(parsed.get("label", "")).strip()
    reasoning = str(parsed.get("reasoning", "")).strip()
    if label in LABELS:
        return LabelResult(label=label, reasoning=reasoning, raw=raw, valid=True)
    # Not an exact hit — accept a case-insensitive match and canonicalize.
    folded = label.lower()
    canonical = next((c for c in LABELS if c.lower() == folded), None)
    if canonical is not None:
        return LabelResult(label=canonical, reasoning=reasoning, raw=raw, valid=True)
    return LabelResult(label=label, reasoning=reasoning, raw=raw, valid=False)
if __name__ == "__main__":
    # Ad-hoc CLI driver: runs both stages back to back, feeding the
    # stage-1 schema to stage 3 verbatim (no interactive edit step).
    import argparse
    from pathlib import Path

    from .raw_preview import build_raw_preview

    parser = argparse.ArgumentParser()
    parser.add_argument("--context", required=True)
    parser.add_argument("--tsv", required=True)
    parser.add_argument("--endpoint", required=True)
    parser.add_argument("--model", default="statlens")
    parser.add_argument("--api-key", default="dummy")
    opts = parser.parse_args()

    study_context = Path(opts.context).read_text()
    tsv_preview = build_raw_preview(Path(opts.tsv))

    print("=== Stage 1: extract schema ===")
    extraction = extract_schema(study_context, tsv_preview,
                                opts.endpoint, opts.model, opts.api_key)
    print(f" warnings : {extraction.schema_warnings}")
    print(json.dumps(extraction.schema, indent=2))

    print("\n=== Stage 3: pick label (using LLM-extracted schema verbatim) ===")
    decision = pick_label(extraction.schema, study_context,
                          opts.endpoint, opts.model, opts.api_key)
    print(f" label : {decision.label} (valid={decision.valid})")
    print(f" reasoning: {decision.reasoning}")