# goon/agent/analysis.py
# (provenance: author "Binx", initial commit da605e9 — "Initial commit: analysis app, deployment config, UI improvements")
"""
Goon analysis agent β€” all modules in one file.
Sections:
1. Response formatter
2. Data inspection & sampling
3. Question router
4. Image analysis
5. Text pattern extraction
6. Analysis execution (count, trend, stats, search, word freq, compare)
7. Core agent loop
"""
from __future__ import annotations
# ── stdlib ─────────────────────────────────────────────────────────────────
import base64
import glob
import io
import json
import os
import re
import traceback as _traceback
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
# ── third-party ────────────────────────────────────────────────────────────
import anthropic
import openai # Together AI uses an OpenAI-compatible endpoint
import pandas as pd
import plotly.express as px
import plotly.io as pio
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from pyarrow.compute import field
from sklearn.metrics import cohen_kappa_score
# ══════════════════════════════════════════════════════════════════════════════
# 1. Response formatter
# ══════════════════════════════════════════════════════════════════════════════
def format_result(result: dict, answer_text: str = "") -> str:
    """Combine Claude's prose answer with the structured analysis result as markdown."""
    out: list[str] = []
    if answer_text:
        out += ["## Answer\n", answer_text.strip(), ""]
    dataset = result.get("dataset", "")
    if dataset:
        out.append("## What was analysed\n")
        out.append(f"- Dataset: `{dataset}`")
        if result.get("subreddit_filter"):
            out.append(f"- Subreddit filter: `{result['subreddit_filter']}`")
        if result.get("group_col"):
            out.append(f"- Grouped by: `{result['group_col']}`")
        if result.get("value_col"):
            out.append(f"- Value column: `{result['value_col']}`")
        out.append("")
    table = result.get("table")
    if table:
        out += ["## Results\n", _dict_list_to_md_table(table), ""]
    artifacts = [result[key] for key in ("saved_csv", "saved_png") if result.get(key)]
    if artifacts:
        out.append("## Saved outputs\n")
        out.extend(f"- `{path}`" for path in artifacts)
        out.append("")
    return "\n".join(out)
def _dict_list_to_md_table(records: list[dict]) -> str:
if not records:
return "_No results._"
headers = list(records[0].keys())
rows = [[str(r.get(h, "")) for h in headers] for r in records]
widths = [max(len(h), max((len(r[i]) for r in rows), default=0)) for i, h in enumerate(headers)]
sep = "| " + " | ".join("-" * w for w in widths) + " |"
header_row = "| " + " | ".join(h.ljust(widths[i]) for i, h in enumerate(headers)) + " |"
data_rows = [
"| " + " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)) + " |"
for row in rows[:50]
]
return "\n".join([header_row, sep] + data_rows)
# ══════════════════════════════════════════════════════════════════════════════
# 2. Data inspection & sampling
# ══════════════════════════════════════════════════════════════════════════════
# Root directory for parquet/CSV inputs; override with the DATA_DIR env var.
DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
# All generated artifacts (CSVs, PNGs, cached metadata) land here.
OUTPUTS_DIR = Path(__file__).parent / "outputs"
OUTPUTS_DIR.mkdir(exist_ok=True)  # idempotent; runs at import time
# JSON cache file written by compute_dataset_metadata().
METADATA_CACHE = OUTPUTS_DIR / "dataset_metadata.json"
def _best(name: str) -> Path:
    """Prefer the full rebuilt parquet over the original, but validate it first.

    A rebuild whose schema cannot be read is treated as corrupt and ignored.
    """
    rebuilt = DATA_DIR / f"{name}_full.parquet"
    fallback = DATA_DIR / f"{name}.parquet"
    if rebuilt.exists():
        try:
            # A readable schema is treated as "valid enough" to prefer the rebuild.
            pq.read_schema(rebuilt)
        except Exception:
            return fallback
        return rebuilt
    return fallback
# Logical dataset names -> on-disk parquet paths. posts/comments/titles prefer
# the validated "_full" rebuild when present; corpus_clean has no rebuilt variant.
DATASETS = {
    "posts": _best("posts"),
    "comments": _best("comments"),
    "corpus_clean": DATA_DIR / "corpus_clean.parquet",
    "titles": _best("titles"),
}
def _dataset_path(name: str) -> Path:
    """Resolve a logical dataset name to an existing parquet path.

    Raises:
        FileNotFoundError: if the name is unknown or the file is missing.
    """
    path = DATASETS.get(name)
    if path is not None and path.exists():
        return path
    raise FileNotFoundError(f"Dataset '{name}' not found at {path}")
def _load(name: str, columns: list[str] | None = None) -> pd.DataFrame:
    """Read a whole dataset (optionally a column subset) into pandas."""
    path = _dataset_path(name)
    return pd.read_parquet(path, columns=columns)
def _scanner(name: str, columns: list[str] | None = None, filters: dict | None = None) -> ds.Scanner:
    """Build a pyarrow scanner with optional equality filters.

    Filter keys that are absent from the schema, or whose value is None/"",
    are silently dropped; the remaining clauses are AND-ed together.
    """
    source = ds.dataset(_dataset_path(name), format="parquet")
    predicate = None
    for column, wanted in (filters or {}).items():
        if wanted in (None, "") or column not in source.schema.names:
            continue
        term = field(column) == wanted
        predicate = term if predicate is None else predicate & term
    return source.scanner(columns=columns, filter=predicate)
def _read_distinct_values(name: str, column: str, limit: int = 200) -> list[str] | None:
    """Return up to `limit` sorted distinct values of a column as strings.

    Returns None when the dataset has no such column.
    """
    if column not in _schema_names(name):
        return None
    arrow_table = _scanner(name, columns=[column]).to_table()
    distinct = arrow_table.column(column).drop_null().unique().to_pylist()
    as_strings = sorted(str(v) for v in distinct)
    return as_strings[:limit]
def _read_date_range(name: str) -> dict | None:
    """Earliest/latest created_utc as UTC date strings, or None when unavailable."""
    if "created_utc" not in _schema_names(name):
        return None
    arrow_table = _scanner(name, columns=["created_utc"]).to_table()
    if not arrow_table.num_rows:
        return None
    timestamps = arrow_table.column("created_utc").to_pandas().dropna()
    if timestamps.empty:
        return None

    def _day(ts: float) -> str:
        # created_utc is epoch seconds; render as a UTC calendar date.
        return datetime.fromtimestamp(ts, tz=timezone.utc).strftime("%Y-%m-%d")

    return {"earliest": _day(timestamps.min()), "latest": _day(timestamps.max())}
def _schema_names(name: str) -> list[str]:
    """Column names of a dataset, read from parquet metadata only (no data scan)."""
    schema = pq.read_schema(_dataset_path(name))
    return schema.names
def compute_dataset_metadata() -> dict:
    """Scan every configured dataset and summarise availability, schema, and ranges.

    Side effect: writes the JSON summary to METADATA_CACHE.
    """
    meta: dict = {}
    for name, path in DATASETS.items():
        if not path.exists():
            meta[name] = {"available": False}
            continue
        parquet_file = pq.ParquetFile(path)
        meta[name] = {
            "available": True,
            "path": str(path),
            "rows": parquet_file.metadata.num_rows,
            "columns": {f.name: str(f.type) for f in parquet_file.schema_arrow},
            "subreddits": _read_distinct_values(name, "subreddit"),
            "date_range": _read_date_range(name),
            "metadata_cached_at": datetime.now(timezone.utc).isoformat(),
        }
    METADATA_CACHE.write_text(json.dumps(meta, indent=2))
    return meta
def get_dataset_metadata(refresh: bool = False) -> dict:
    """Return dataset metadata, recomputing only when forced or no cache exists."""
    if refresh or not METADATA_CACHE.exists():
        return compute_dataset_metadata()
    return json.loads(METADATA_CACHE.read_text())
def list_datasets(refresh: bool = False) -> dict:
    """Return cached dataset metadata instead of loading full tables.

    Thin tool-facing alias for get_dataset_metadata; pass refresh=True to
    force a rescan of the parquet files.
    """
    return get_dataset_metadata(refresh=refresh)
def sample_rows(
    dataset: str,
    n: int = 5,
    filters: dict | None = None,
    columns: list[str] | None = None,
) -> dict:
    """Return a small deterministic preview of rows from a dataset, optionally filtered."""
    use_cols = columns if columns else _schema_names(dataset)
    head = _scanner(dataset, columns=use_cols, filters=filters).head(n)
    if head.num_rows:
        preview = head.to_pandas()
    else:
        # Keep the column layout even when the filter matches nothing.
        preview = pd.DataFrame(columns=use_cols)
    return {
        "dataset": dataset,
        "filters": filters or {},
        "n_returned": len(preview),
        "rows": preview.fillna("").to_dict(orient="records"),
    }
# ══════════════════════════════════════════════════════════════════════════════
# 3. Question router
# ══════════════════════════════════════════════════════════════════════════════
@dataclass(frozen=True)
class RoutePlan:
    """Routing decision for one user question: which tools may be used and how."""
    mode: str                  # short label for the question category (e.g. "trend")
    allowed_tools: list[str]   # subset of ALL_TOOL_NAMES the agent may call
    guidance: str              # extra per-mode guidance injected into the prompt
# Master list of every tool the agent can expose; route_question() narrows
# this list per question category.
ALL_TOOL_NAMES = [
    "list_datasets", "sample_rows", "count_by_group", "trend_over_time",
    "summary_stats", "top_posts", "text_search", "word_freq", "compare_groups",
    "extract_frequency_patterns", "extract_dominance_patterns", "analyze_image_sample",
    "export_reliability_sample", "compute_reliability",
]
def route_question(question: str) -> RoutePlan:
    """Map a free-text question to a RoutePlan via ordered keyword heuristics.

    The first rule whose trigger substring appears in the lowercased question
    wins, so rule order encodes priority (image > reliability > frequency >
    dominance > trend > compare > ranking > search > lexical > counts).
    NOTE(review): triggers are plain substring tests, so short tokens like
    "vs" or "top" can fire inside longer words.
    """
    q = question.lower()
    rules: list[tuple[list[str], RoutePlan]] = [
        (["image", "images", "photo", "photos", "visual", "depicted"],
         RoutePlan(
             mode="image",
             allowed_tools=["list_datasets", "sample_rows", "analyze_image_sample", "export_reliability_sample", "compute_reliability"],
             guidance="This is a visual-content question. Prefer image analysis tools and avoid text-only proxies. Always provide a coding_scheme.",
         )),
        (["reliability", "kappa", "human coding", "inter-rater", "validate"],
         RoutePlan(
             mode="reliability",
             allowed_tools=["export_reliability_sample", "compute_reliability"],
             guidance="This is a reliability/validation question. Use export_reliability_sample then compute_reliability.",
         )),
        (["how often", "how long", "times per", "every day", "session length", "streak"],
         RoutePlan(
             mode="pattern_frequency",
             allowed_tools=["list_datasets", "sample_rows", "extract_frequency_patterns", "text_search"],
             guidance="This is a behavioral frequency/duration question. Prefer regex pattern extraction over generic word counts.",
         )),
        (["dominant", "subordinate", "mistress", "goddess", "femdom", "submissive"],
         RoutePlan(
             mode="pattern_dominance",
             allowed_tools=["list_datasets", "sample_rows", "extract_dominance_patterns", "text_search", "analyze_image_sample"],
             guidance="This is a dominance/subordination framing question. Use the text pattern tool unless the user explicitly asks about images.",
         )),
        (["over time", "trend", "changed", "change over time", "monthly", "yearly"],
         RoutePlan(
             mode="trend",
             allowed_tools=["list_datasets", "sample_rows", "trend_over_time", "count_by_group"],
             guidance="This is a time-series question. Prefer trend_over_time and only use grouping/count tools to contextualize it.",
         )),
        (["compare", "difference", "versus", "vs", "higher", "lower"],
         RoutePlan(
             mode="compare",
             allowed_tools=["list_datasets", "sample_rows", "compare_groups", "summary_stats", "count_by_group"],
             guidance="This is a comparison question. Prefer compare_groups or summary_stats with explicit filters.",
         )),
        (["top", "highest", "best scoring", "most upvoted"],
         RoutePlan(
             mode="ranking",
             allowed_tools=["list_datasets", "sample_rows", "top_posts", "summary_stats"],
             guidance="This is a ranking question. Prefer top_posts and use summary_stats only if it supports the answer.",
         )),
        (["search", "find", "mention", "contains", "where people say"],
         RoutePlan(
             mode="search",
             allowed_tools=["list_datasets", "sample_rows", "text_search", "top_posts"],
             guidance="This is a retrieval question. Prefer text_search with the right dataset and text column.",
         )),
        (["common words", "most common words", "word frequency", "tokens"],
         RoutePlan(
             mode="lexical",
             allowed_tools=["list_datasets", "sample_rows", "word_freq", "text_search"],
             guidance="This is a lexical summary question. Prefer word_freq and inspect text samples only if needed.",
         )),
        (["how many", "count", "number of", "what proportion"],
         RoutePlan(
             mode="describe",
             allowed_tools=["list_datasets", "sample_rows", "count_by_group", "summary_stats", "trend_over_time"],
             guidance="This is a descriptive count question. Prefer count_by_group or summary_stats and keep the plan minimal.",
         )),
    ]
    for triggers, plan in rules:
        if any(t in q for t in triggers):
            return plan
    return RoutePlan(
        mode="unknown",
        allowed_tools=ALL_TOOL_NAMES,
        guidance="Question type is ambiguous. Inspect metadata first, then choose the minimum reliable tool path.",
    )
# ══════════════════════════════════════════════════════════════════════════════
# 4. Image analysis
# ══════════════════════════════════════════════════════════════════════════════
# Vision model served by Together AI (OpenAI-compatible endpoint).
VISION_MODEL = "Qwen/Qwen2-VL-72B-Instruct"
TOGETHER_BASE_URL = "https://api.together.xyz/v1"
# Hosts that serve the raw image at the post URL (no HTML wrapper) — the only
# posts analyze_image_sample can fetch directly.
DIRECT_IMAGE_DOMAINS = {"i.redd.it", "i.imgur.com", "i.redgifs.com"}
def _load_image_urls(subreddit: str | None = None, n: int = 50) -> pd.DataFrame:
    """Collect candidate direct-image post rows from the raw submission CSVs.

    Returns up to ``n * 10`` candidates (sampled with a fixed seed when more
    exist) so the caller can tolerate URLs that fail to fetch.
    """
    csv_paths = sorted(glob.glob(str(DATA_DIR / "*_submissions_*.csv")))
    if subreddit:
        prefix = subreddit.lower()
        csv_paths = [p for p in csv_paths if Path(p).name.lower().startswith(prefix)]
    wanted = ["subreddit", "title", "url", "domain", "score", "is_self"]
    chunks: list[pd.DataFrame] = []
    for path in csv_paths:
        try:
            frame = pd.read_csv(path, usecols=lambda c: c in wanted, low_memory=False)
            if "is_self" in frame.columns:
                frame = frame[frame["is_self"] == False]  # noqa: E712 — drop self posts
            if "url" in frame.columns and "domain" in frame.columns:
                frame = frame[frame["domain"].isin(DIRECT_IMAGE_DOMAINS)].dropna(subset=["url"])
                chunks.append(frame[["subreddit", "title", "url", "domain", "score"]])
        except Exception:
            continue  # best-effort: skip unreadable or oddly-shaped CSVs
    if not chunks:
        return pd.DataFrame()
    pool = pd.concat(chunks, ignore_index=True)
    cap = n * 10
    if len(pool) > cap:
        pool = pool.sample(cap, random_state=42)  # deterministic down-sample
    return pool.head(cap)
def _fetch_image_b64(url: str, timeout: int = 8) -> tuple[str, str] | None:
    """Download one image and return (base64_payload, content_type), or None.

    Returns None for non-ASCII URLs, non-image responses, tiny bodies
    (<1000 bytes, likely an error placeholder), and any network failure.
    """
    try:
        url.encode("ascii")
    except UnicodeEncodeError:
        return None  # urllib cannot handle raw non-ASCII URLs
    try:
        request = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0 (research bot)"})
        with urllib.request.urlopen(request, timeout=timeout) as response:
            mime = response.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
            if not mime.startswith("image/"):
                return None
            payload = response.read()
            if len(payload) < 1000:
                return None
            return base64.standard_b64encode(payload).decode("utf-8"), mime
    except Exception:
        return None
def analyze_image_sample(
    question: str,
    subreddit: str | None = None,
    n_sample: int = 100,
    coding_scheme: dict | None = None,
) -> dict:
    """
    Sample image posts, fetch them, and ask Qwen2-VL a structured content-analysis question.
    Uses Together AI (no content filters). n_sample is uncapped — set as needed.

    Args:
        question: The content-analysis question asked for every image.
        subreddit: Optional subreddit (CSV filename prefix) filter.
        n_sample: Target number of successfully coded images.
        coding_scheme: Optional {label: definition} mapping; when given, the
            model is told to answer as "LABEL | justification".

    Returns:
        Dict with label counts/percentages, per-image rows, the saved CSV
        path, and methodological caveats. Requires TOGETHER_API_KEY in env.
    """
    client = openai.OpenAI(
        api_key=os.environ["TOGETHER_API_KEY"],
        base_url=TOGETHER_BASE_URL,
    )
    # Over-sample candidates: many URLs 404 or turn out not to be images.
    candidates = _load_image_urls(subreddit=subreddit, n=n_sample * 5)
    if candidates.empty:
        return {
            "analysis": "analyze_image_sample",
            "error": "No direct image URLs found in raw CSVs for the given filters.",
            "subreddit_filter": subreddit,
        }
    if coding_scheme:
        scheme_text = "\n".join(f"- {k}: {v}" for k, v in coding_scheme.items())
        prompt = (
            f"{question}\n\nCoding scheme:\n{scheme_text}\n\n"
            "Reply with ONLY the label and a one-sentence justification, "
            "formatted as: LABEL | justification"
        )
    else:
        prompt = (
            f"{question}\n\n"
            "Reply with a short structured answer. "
            "If you cannot determine this from the image, reply: UNCLEAR | reason"
        )

    def _ascii_safe(s: str) -> str:
        # CSV/console safety: replace anything outside ASCII.
        # (Fixed: previously re-defined inside the per-row loop on every iteration.)
        return s.encode("ascii", errors="replace").decode("ascii")

    results = []
    attempted = 0
    for _, row in candidates.iterrows():
        if len(results) >= n_sample:
            break
        attempted += 1
        img = _fetch_image_b64(row["url"])
        if img is None:
            continue  # dead link / not an image — keep trying candidates
        b64data, media_type = img
        try:
            response = client.chat.completions.create(
                model=VISION_MODEL,
                max_tokens=200,
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{b64data}"}},
                        {"type": "text", "text": prompt},
                    ],
                }],
            )
            answer = response.choices[0].message.content.strip()
        except Exception as e:
            # Record API failures as a coded row so they are visible downstream.
            answer = f"ERROR | {e}"
        parts = answer.split("|", 1)
        label = parts[0].strip().upper() if parts else "UNCLEAR"
        justification = parts[1].strip() if len(parts) > 1 else ""
        results.append({
            "subreddit": _ascii_safe(str(row.get("subreddit", ""))),
            "title": _ascii_safe(str(row.get("title", ""))),
            "url": row["url"],
            "label": label,
            "justification": _ascii_safe(justification),
            "score": row.get("score", None),
        })
    label_counts: dict[str, int] = {}
    for r in results:
        label_counts[r["label"]] = label_counts.get(r["label"], 0) + 1
    total_coded = len(results)
    saved_csv = None
    if results:
        out_df = pd.DataFrame(results)
        stem = f"image_analysis_{subreddit or 'all'}"
        saved_csv = str(OUTPUTS_DIR / f"{stem}.csv")
        out_df.to_csv(saved_csv, index=False)
    return {
        "analysis": "analyze_image_sample",
        "question": question,
        "subreddit_filter": subreddit,
        "n_attempted": attempted,
        "n_successfully_coded": total_coded,
        "label_counts": label_counts,
        "label_pct": {k: round(v / total_coded * 100, 1) for k, v in label_counts.items()} if total_coded else {},
        "per_image_results": results,
        "saved_csv": saved_csv,
        "caveats": [
            "Sample limited to direct-image domains (i.redd.it, i.imgur.com, i.redgifs.com) — galleries and videos excluded.",
            f"Vision model: {VISION_MODEL} via Together AI.",
            "Coded by a single model — validate with human reliability sample before reporting.",
        ],
    }
def export_reliability_sample(
    source_csv: str | None = None,
    n: int = 200,
    random_state: int = 42,
) -> dict:
    """
    Draw a stratified random sample of n images from a completed image_analysis CSV
    for human coding. Saves a CSV with an empty human_label column.

    Args:
        source_csv: Path to an image_analysis_*.csv. Defaults to the
            lexicographically last one in outputs/ (NOTE(review): glob sort
            order, not mtime — only "most recent" if names sort that way).
        n: Target sample size; proportional allocation per model label, with
            at least one row per label, topped up to n when possible.
        random_state: Seed for reproducible sampling and shuffling.
    """
    if source_csv is None:
        # Default to most recently written image_analysis file
        candidates = sorted(OUTPUTS_DIR.glob("image_analysis_*.csv"))
        if not candidates:
            return {"error": "No image_analysis CSV found in outputs/. Run analyze_image_sample first."}
        source_csv = str(candidates[-1])
    df = pd.read_csv(source_csv)
    # Drop rows the model failed to code (missing or ERROR labels).
    df = df[df["label"].notna() & ~df["label"].str.startswith("ERROR")]
    # Stratified sample by label
    sampled = (
        df.groupby("label", group_keys=False)
        .apply(lambda g: g.sample(min(len(g), max(1, int(n * len(g) / len(df)))), random_state=random_state))
    )
    # Top up to exactly n if rounding left us short
    if len(sampled) < n and len(df) >= n:
        remaining = df[~df.index.isin(sampled.index)]
        top_up = remaining.sample(n - len(sampled), random_state=random_state)
        sampled = pd.concat([sampled, top_up])
    sampled = sampled.sample(frac=1, random_state=random_state).reset_index(drop=True)  # shuffle
    sampled.insert(0, "image_id", range(1, len(sampled) + 1))
    sampled = sampled.rename(columns={"label": "model_label", "justification": "model_justification"})
    sampled["human_label"] = ""  # to be filled in by the human coder
    out_cols = ["image_id", "url", "title", "subreddit", "model_label", "model_justification", "human_label"]
    out_cols = [c for c in out_cols if c in sampled.columns]
    out_path = str(OUTPUTS_DIR / "reliability_sample.csv")
    sampled[out_cols].to_csv(out_path, index=False)
    return {
        "analysis": "export_reliability_sample",
        "source_csv": source_csv,
        "n_exported": len(sampled),
        "label_distribution": sampled["model_label"].value_counts().to_dict(),
        "saved_csv": out_path,
        "next_step": "Fill in the human_label column, then run compute_reliability.",
    }
def compute_reliability(human_csv_path: str | None = None) -> dict:
    """
    Compute Cohen's kappa between model_label and human_label columns
    in a completed reliability sample CSV.

    Side effect: writes reliability_report.json to the outputs directory.
    """
    if human_csv_path is None:
        human_csv_path = str(OUTPUTS_DIR / "reliability_sample.csv")
    coded = pd.read_csv(human_csv_path)
    # Keep only rows the human actually filled in.
    has_human = coded["human_label"].notna() & (coded["human_label"].astype(str).str.strip() != "")
    coded = coded[has_human]
    if len(coded) < 2:
        return {"error": "Not enough human-coded rows. Fill in human_label column first."}
    model = coded["model_label"].astype(str).str.strip().str.upper()
    human = coded["human_label"].astype(str).str.strip().str.upper()
    kappa = cohen_kappa_score(human, model)
    pct_agreement = round((human == model).mean() * 100, 1)
    per_label = {}
    for lab in sorted(human.unique()):
        is_human = human == lab
        is_model = model == lab
        per_label[lab] = {
            "human_n": int(is_human.sum()),
            "model_n": int(is_model.sum()),
            "exact_matches": int((is_human & is_model).sum()),
            "false_positives": int((~is_human & is_model).sum()),
            "false_negatives": int((is_human & ~is_model).sum()),
        }
    # Conventional Landis & Koch-style verdict bands.
    if kappa >= 0.80:
        verdict = "excellent (κ ≥ 0.80)"
    elif kappa >= 0.60:
        verdict = "substantial (κ 0.60–0.79)"
    elif kappa >= 0.40:
        verdict = "moderate (κ 0.40–0.59)"
    elif kappa >= 0.20:
        verdict = "fair (κ 0.20–0.39)"
    else:
        verdict = "poor (κ < 0.20)"
    report = {
        "analysis": "compute_reliability",
        "n_coded": len(coded),
        "cohens_kappa": round(kappa, 3),
        "percent_agreement": pct_agreement,
        "interpretation": verdict,
        "per_label": per_label,
    }
    out_path = str(OUTPUTS_DIR / "reliability_report.json")
    Path(out_path).write_text(json.dumps(report, indent=2))
    report["saved_json"] = out_path
    return report
# ══════════════════════════════════════════════════════════════════════════════
# 5. Text pattern extraction
# ══════════════════════════════════════════════════════════════════════════════
# Regexes that mine self-reported frequency/duration language. Categories whose
# patterns have a capture group yield numeric values; the rest are presence tests.
FREQUENCY_PATTERNS = {
    "times_per_day": [
        r"\b(\d+)\s*(?:times?|x)\s*(?:a|per)\s*day\b",
        r"\b(\d+)\s*(?:times?|x)\s*daily\b",
    ],
    "times_per_week": [
        r"\b(\d+)\s*(?:times?|x)\s*(?:a|per)\s*week\b",
        r"\b(\d+)\s*(?:times?|x)\s*weekly\b",
    ],
    # NOTE(review): the first alternative already matches any "<n> hours", so
    # the stricter second alternative can never add matches — confirm intent.
    "hours_per_session": [
        r"\b(\d+(?:\.\d+)?)\s*(?:hours?|hrs?)\b",
        r"\b(\d+(?:\.\d+)?)\s*(?:hours?|hrs?)\s*(?:session|goon|long|straight|solid|non.?stop)\b",
    ],
    "all_day": [
        r"\ball\s*day\b", r"\ball\s*night\b", r"\ball\s*weekend\b", r"\bfor\s*hours\b",
    ],
    "daily_habit": [
        r"\bevery\s*day\b", r"\bevery\s*night\b", r"\bdaily\b", r"\bmost\s*days?\b",
    ],
    "streak_days": [
        r"\b(\d+)\s*(?:days?\s*(?:in\s*a\s*row|straight|streak|running))\b",
        r"\b(\d+)\s*(?:-|–)?\s*day\s*(?:streak|binge)\b",
    ],
}
# Regexes for dominance-framing language in text. Purely lexical: they detect
# the presence of a term, not its sentiment or referent.
DOMINANCE_PATTERNS = {
    "dominant_language": [
        # fixed: the alternation previously listed "rix" twice (harmless but dead)
        r"\bdominat(?:e|es|ed|ing|ion|rix)\b", r"\bfem(?:dom|domme)\b",
        r"\bmistress\b", r"\bgoddess\b", r"\bqueen\b", r"\bowner\b",
        r"\balpha\b", r"\bin\s*control\b", r"\bboss\b",
    ],
    "subordinate_language": [
        r"\bsubmiss(?:ive|ion)\b", r"\bsub\b", r"\bobedient\b", r"\bslave\b",
        r"\bpet\b", r"\bslut\b", r"\bwhore\b", r"\bused\b",
        r"\bcontrolled\b", r"\bworshiped?\b", r"\bworship(?:ped|ing)\b",
    ],
    # Objectifying-but-neutral descriptors, used as a baseline category.
    "neutral_object": [
        r"\bperfect\b", r"\bbeautiful\b", r"\bhot\b", r"\bsexy\b", r"\bstunning\b",
    ],
}
def _compile(patterns: list[str]) -> re.Pattern:
return re.compile("|".join(patterns), re.IGNORECASE)
def extract_frequency_patterns(
    dataset: str = "comments",
    text_col: str = "body",
    subreddit: str | None = None,
    n_examples: int = 5,
    sample_size: int = 5_000_000,
) -> dict:
    """Mine text for frequency and duration language across the full dataset.

    Args:
        dataset: Logical dataset name (see DATASETS).
        text_col: Column holding the raw text.
        subreddit: Optional exact-match subreddit filter.
        n_examples: Example documents returned per category.
        sample_size: Upper bound on documents scanned.

    Returns:
        Per-category match counts/percentages, extracted numeric values
        (first 50, sorted), their mean, and ASCII-sanitized example texts.
    """
    cols = [text_col] + (["subreddit"] if subreddit else [])
    df = _scanner(
        dataset, columns=cols,
        filters={"subreddit": subreddit} if subreddit else None,
    ).head(sample_size).to_pandas()
    text = df[text_col].fillna("")
    total_docs = len(text)
    results = {}
    for category, pats in FREQUENCY_PATTERNS.items():
        regex = _compile(pats)
        # BUG FIX: pass the compiled regex (not regex.pattern) so the
        # IGNORECASE flag applies to the containment mask — previously the
        # mask was case-sensitive while the value extraction below was not
        # (and unlike the sibling extract_dominance_patterns).
        matches_mask = text.str.contains(regex, na=False)
        hit_texts = text[matches_mask]
        # Pull numeric captures (e.g. "3 times a day" -> 3.0) per pattern.
        values = []
        for pat in pats:
            r = re.compile(pat, re.IGNORECASE)
            for t in hit_texts:
                for m in r.finditer(t):
                    if m.groups():
                        try:
                            values.append(float(m.group(1)))
                        except (IndexError, ValueError):
                            pass
        raw_examples = hit_texts.sample(min(n_examples, len(hit_texts)), random_state=42).tolist() if len(hit_texts) > 0 else []
        results[category] = {
            "count": int(matches_mask.sum()),
            "pct_of_docs": round(matches_mask.mean() * 100, 3),
            "numeric_values": sorted(values)[:50] if values else [],
            "mean_value": round(sum(values) / len(values), 2) if values else None,
            "examples": [t.encode("ascii", errors="replace").decode("ascii") for t in raw_examples],
        }
    return {
        "analysis": "extract_frequency_patterns",
        "dataset": dataset,
        "text_col": text_col,
        "subreddit_filter": subreddit,
        "total_docs_sampled": total_docs,
        "patterns": results,
    }
def extract_dominance_patterns(
    dataset: str = "comments",
    text_col: str = "body",
    subreddit: str | None = None,
    sample_size: int = 5_000_000,
) -> dict:
    """Count dominant, subordinate, and neutral language in text."""
    wanted_cols = [text_col] + (["subreddit"] if subreddit else [])
    frame = _scanner(
        dataset, columns=wanted_cols,
        filters={"subreddit": subreddit} if subreddit else None,
    ).head(sample_size).to_pandas()
    docs = frame[text_col].fillna("")
    n_docs = len(docs)
    categories = {}
    for label, alternatives in DOMINANCE_PATTERNS.items():
        matcher = _compile(alternatives)
        hit_mask = docs.str.contains(matcher, na=False)
        matched = docs[hit_mask]
        if len(matched):
            examples = matched.sample(min(5, len(matched)), random_state=42).tolist()
        else:
            examples = []
        categories[label] = {
            "count": int(hit_mask.sum()),
            "pct_of_docs": round(hit_mask.mean() * 100, 3),
            "examples": [e.encode("ascii", errors="replace").decode("ascii") for e in examples],
        }
    n_dom = categories.get("dominant_language", {}).get("count", 0)
    n_sub = categories.get("subordinate_language", {}).get("count", 0)
    combined = n_dom + n_sub
    if combined:
        if n_sub > n_dom:
            reading = "More subordinate language"
        elif n_dom > n_sub:
            reading = "More dominant language"
        else:
            reading = "Roughly balanced"
    else:
        reading = "No data"
    ratio = {
        "dominant_pct": round(n_dom / combined * 100, 1) if combined else None,
        "subordinate_pct": round(n_sub / combined * 100, 1) if combined else None,
        "interpretation": reading,
    }
    return {
        "analysis": "extract_dominance_patterns",
        "dataset": dataset,
        "text_col": text_col,
        "subreddit_filter": subreddit,
        "total_docs_sampled": n_docs,
        "categories": categories,
        "dominance_ratio": ratio,
        "caveat": (
            "This analysis counts language patterns in text, not visual image content. "
            "It reflects how women are described in text, not how they appear in images. "
            "For image-based analysis use analyze_image_sample."
        ),
    }
# ══════════════════════════════════════════════════════════════════════════════
# 6. Analysis execution
# ══════════════════════════════════════════════════════════════════════════════
def _ts_to_date(series: pd.Series) -> pd.Series:
return pd.to_datetime(series, unit="s", utc=True)
def _normalize_filters(
filters: dict | None = None,
subreddit: str | None = None,
date_from: str | None = None,
date_to: str | None = None,
min_score: float | None = None,
) -> dict:
merged = dict(filters or {})
if subreddit:
merged["subreddit"] = subreddit
if date_from:
merged["date_from"] = date_from
if date_to:
merged["date_to"] = date_to
if min_score is not None:
merged["min_score"] = min_score
return merged
def _apply_filters(df: pd.DataFrame, filters: dict | None = None) -> pd.DataFrame:
if not filters:
return df
filtered = df
if filters.get("subreddit") and "subreddit" in filtered.columns:
filtered = filtered[filtered["subreddit"] == filters["subreddit"]]
if filters.get("author") and "author" in filtered.columns:
filtered = filtered[filtered["author"] == filters["author"]]
if filters.get("min_score") is not None and "score" in filtered.columns:
filtered["score"] = pd.to_numeric(filtered["score"], errors="coerce")
filtered = filtered[filtered["score"] >= filters["min_score"]]
if ("date_from" in filters or "date_to" in filters) and "created_utc" in filtered.columns:
created = _ts_to_date(filtered["created_utc"])
if filters.get("date_from"):
filtered = filtered[created >= pd.Timestamp(filters["date_from"], tz="UTC")]
created = _ts_to_date(filtered["created_utc"])
if filters.get("date_to"):
filtered = filtered[created <= pd.Timestamp(filters["date_to"], tz="UTC")]
return filtered
def _save_csv(df: pd.DataFrame, stem: str) -> str:
    """Write a frame to outputs/<stem>.csv (no index) and return the path string."""
    target = OUTPUTS_DIR / f"{stem}.csv"
    df.to_csv(target, index=False)
    return str(target)
def _save_fig(fig, stem: str) -> str:
    """Render a plotly figure to outputs/<stem>.png at 2x scale; return the path string."""
    target = OUTPUTS_DIR / f"{stem}.png"
    pio.write_image(fig, str(target), scale=2)
    return str(target)
def count_by_group(dataset: str, group_col: str, top_n: int = 30, filters: dict | None = None) -> dict:
    """Count rows grouped by a column. Returns sorted table + bar chart."""
    # Load the grouping column plus every column the filter dict may act on.
    extra = [c for c in ["subreddit", "author", "score", "created_utc"] if c != group_col]
    rows = _load(dataset, columns=list(dict.fromkeys([group_col] + extra)))
    rows = _apply_filters(rows, filters)
    tally = (
        rows.groupby(group_col, dropna=False)
        .size().reset_index(name="count")
        .sort_values("count", ascending=False).head(top_n)
    )
    stem = f"count_by_{group_col}_{dataset}"
    csv_path = _save_csv(tally, stem)
    chart = px.bar(
        tally.sort_values("count"), x="count", y=group_col, orientation="h",
        title=f"Count by {group_col}", labels={"count": "Count", group_col: group_col},
    )
    chart.update_layout(yaxis={"categoryorder": "total ascending"})
    try:
        png_path = _save_fig(chart, stem)
    except Exception:
        png_path = None  # PNG export is best-effort; CSV + JSON remain
    return {
        "analysis": "count_by_group", "dataset": dataset, "group_col": group_col,
        "filters": filters or {}, "total_rows": len(rows),
        "table": tally.to_dict(orient="records"),
        "saved_csv": csv_path, "saved_png": png_path, "plotly_json": chart.to_json(),
    }
def trend_over_time(
    dataset: str, freq: str = "M", group_col: str | None = None,
    top_groups: int = 8, filters: dict | None = None,
) -> dict:
    """Count posts/comments over time, optionally broken out by a grouping column.

    Args:
        dataset: Logical dataset name.
        freq: pandas period alias for time buckets (e.g. "M" monthly).
        group_col: Optional breakout column; only the `top_groups` most
            frequent values are charted.
        top_groups: How many group values to keep when group_col is given.
        filters: Standard filter dict (see _apply_filters).
    """
    # Always load the filterable columns so _apply_filters can act on them.
    cols = ["created_utc"] + ([group_col] if group_col else []) + ["subreddit", "author", "score"]
    cols = list(dict.fromkeys(cols))  # dedupe while preserving order
    df = _load(dataset, columns=cols)
    df = _apply_filters(df, filters)
    df["period"] = _ts_to_date(df["created_utc"]).dt.to_period(freq).astype(str)
    if group_col:
        # Restrict to the most common group values to keep the chart legible.
        top = df[group_col].value_counts().head(top_groups).index.tolist()
        df = df[df[group_col].isin(top)]
        counts = (
            df.groupby(["period", group_col]).size()
            .reset_index(name="count").sort_values("period")
        )
        fig = px.line(counts, x="period", y="count", color=group_col,
                      title=f"Activity over time by {group_col}")
    else:
        counts = (
            df.groupby("period").size()
            .reset_index(name="count").sort_values("period")
        )
        fig = px.line(counts, x="period", y="count", title="Activity over time")
    stem = f"trend_{dataset}_{group_col or 'all'}_{freq}"
    saved_csv = _save_csv(counts, stem)
    try:
        saved_png = _save_fig(fig, stem)
    except Exception:
        # PNG export is best-effort (needs a renderer); CSV + JSON remain.
        saved_png = None
    return {
        "analysis": "trend_over_time", "dataset": dataset, "freq": freq,
        "group_col": group_col, "filters": filters or {},
        "table": counts.to_dict(orient="records"),
        "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": fig.to_json(),
    }
def summary_stats(
    dataset: str, value_col: str, group_col: str | None = None,
    top_n: int = 30, filters: dict | None = None,
) -> dict:
    """Descriptive statistics for a numeric column, optionally by group.

    Args:
        dataset: Logical dataset name.
        value_col: Column to summarise (coerced to numeric; bad values -> NaN).
        group_col: Optional grouping column; groups are ranked by mean.
        top_n: Max number of groups returned.
        filters: Standard filter dict (see _apply_filters).
    """
    cols = [value_col] + ([group_col] if group_col else []) + ["subreddit", "author", "score", "created_utc"]
    cols = list(dict.fromkeys(cols))  # dedupe while preserving order
    df = _load(dataset, columns=cols)
    df = _apply_filters(df, filters)
    # Coerce so describe()/agg() never trip over string-typed numerics.
    df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
    if group_col:
        stats = (
            df.groupby(group_col)[value_col]
            .agg(["count", "mean", "median", "std", "min", "max"])
            .reset_index().sort_values("mean", ascending=False).head(top_n).round(2)
        )
    else:
        raw = df[value_col].describe().round(2)
        stats = raw.reset_index()
        stats.columns = ["stat", "value"]
    stem = f"stats_{value_col}_{group_col or 'all'}_{dataset}"
    saved_csv = _save_csv(stats, stem)
    try:
        if group_col:
            fig = px.bar(stats, x=group_col, y="mean", error_y="std",
                         title=f"{value_col} by {group_col}",
                         labels={"mean": f"Mean {value_col}"})
        else:
            fig = px.histogram(df[value_col].dropna(), nbins=50,
                               title=f"Distribution of {value_col}",
                               labels={"value": value_col})
        saved_png = _save_fig(fig, stem)
        plotly_json = fig.to_json()
    except Exception:
        # Figure rendering is best-effort; the CSV is still produced.
        saved_png = None
        plotly_json = None
    return {
        "analysis": "summary_stats", "dataset": dataset, "value_col": value_col,
        "group_col": group_col, "filters": filters or {},
        "n_total": len(df), "n_missing": int(df[value_col].isna().sum()),
        "table": stats.to_dict(orient="records"),
        "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": plotly_json,
    }
def top_posts(
    dataset: str = "posts", n: int = 20,
    subreddit: str | None = None, text_col: str = "title",
    filters: dict | None = None,
) -> dict:
    """Return the highest-scoring posts, optionally filtered to a subreddit."""
    merged_filters = _normalize_filters(filters=filters, subreddit=subreddit)
    wanted = [c for c in ["subreddit", "author", text_col, "score", "created_utc"] if c]
    frame = _apply_filters(_load(dataset, columns=wanted), merged_filters)
    ranked = frame.nlargest(n, "score")[wanted].copy()
    # Swap the raw epoch column for a readable date string.
    ranked["date"] = _ts_to_date(ranked["created_utc"]).dt.strftime("%Y-%m-%d")
    ranked = ranked.drop(columns=["created_utc"])
    stem = f"top_posts_{subreddit or 'all'}_{dataset}"
    csv_path = _save_csv(ranked, stem)
    return {
        "analysis": "top_posts", "dataset": dataset,
        "subreddit_filter": subreddit, "filters": merged_filters, "n": n,
        "table": ranked.fillna("").to_dict(orient="records"), "saved_csv": csv_path,
    }
def text_search(
    dataset: str, query: str, text_col: str = "body",
    n: int = 20, case_sensitive: bool = False,
    subreddit: str | None = None, filters: dict | None = None,
) -> dict:
    """Search for a string pattern in a text column."""
    merged_filters = _normalize_filters(filters=filters, subreddit=subreddit)
    wanted = [c for c in ["subreddit", "author", text_col, "score", "created_utc"] if c]
    frame = _apply_filters(_load(dataset, columns=wanted), merged_filters)
    # Literal substring match (regex=False); case folding governed by the flag.
    match_mask = frame[text_col].fillna("").str.contains(query, case=case_sensitive, regex=False)
    best = frame[match_mask].nlargest(n, "score").copy()
    best["date"] = _ts_to_date(best["created_utc"]).dt.strftime("%Y-%m-%d")
    best = best.drop(columns=["created_utc"])
    stem = f"search_{query[:30].replace(' ', '_')}_{dataset}"
    csv_path = _save_csv(best, stem)
    return {
        "analysis": "text_search", "dataset": dataset, "query": query,
        "text_col": text_col, "filters": merged_filters,
        "n_matches": int(match_mask.sum()), "n_returned": len(best),
        "table": best.fillna("").to_dict(orient="records"), "saved_csv": csv_path,
    }
def word_freq(
    dataset: str = "corpus_clean", text_col: str = "text_cleaned",
    top_n: int = 50, subreddit: str | None = None,
    min_length: int = 4, filters: dict | None = None,
) -> dict:
    """Count word frequencies in a text column.

    Args:
        dataset: Name of the dataset to load.
        text_col: Text column to tokenise.
        top_n: Number of most frequent words to return.
        subreddit: Convenience shortcut merged into filters.
        min_length: Minimum word length (in characters) to keep.
        filters: Optional equality filters applied before counting.

    Returns:
        Dict with the frequency table, saved CSV/PNG paths, and Plotly JSON.
    """
    filters = _normalize_filters(filters=filters, subreddit=subreddit)
    # Always load the subreddit column: `filters` may carry a subreddit key even
    # when the `subreddit` argument is None (callers can pass it directly), and
    # the old conditional column list made _apply_filters raise a KeyError in
    # that case. Sibling tools (top_posts, text_search) already load it always.
    cols = list(dict.fromkeys([text_col, "subreddit", "author", "score", "created_utc"]))
    df = _load(dataset, columns=cols)
    df = _apply_filters(df, filters)
    # Common English stopwords plus Reddit boilerplate tokens.
    stop = {
        "the","and","for","that","with","this","you","are","was","not",
        "have","from","they","will","what","been","when","your","more",
        "just","about","like","there","were","would","into","than","then",
        "some","also","very","only","over","back","can","out","all","but",
        "one","had","has","its","which","their","time","our","who","may",
        "after","other","these","those","such","each","him","her","his",
        "she","how","did","being","now","way","any","too","much","even",
        "get","got","could","should","make","made","said","still",
        "here","because","really","know","think","going","reddit","post",
        "comment","deleted","removed",
    }
    # Lowercase, strip everything but letters, then explode into one word per row.
    words = (
        df[text_col].fillna("").str.lower()
        .str.replace(r"[^a-z\s]", " ", regex=True).str.split().explode()
    )
    words = words[words.str.len() >= min_length]
    words = words[~words.isin(stop)]
    counts = words.value_counts().head(top_n).reset_index()
    counts.columns = ["word", "count"]
    stem = f"wordfreq_{text_col}_{subreddit or 'all'}_{dataset}"
    saved_csv = _save_csv(counts, stem)
    fig = px.bar(
        counts.head(30).sort_values("count"), x="count", y="word", orientation="h",
        title="Top words by frequency", labels={"count": "Count", "word": "Word"},
    )
    fig.update_layout(yaxis={"categoryorder": "total ascending"})
    try:
        saved_png = _save_fig(fig, stem)
    except Exception:
        # Static image export may be unavailable (e.g. missing kaleido); the
        # interactive Plotly JSON below still works without it.
        saved_png = None
    return {
        "analysis": "word_freq", "dataset": dataset, "text_col": text_col,
        "subreddit_filter": subreddit, "filters": filters, "total_docs": len(df),
        "table": counts.to_dict(orient="records"),
        "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": fig.to_json(),
    }
def compare_groups(
    dataset: str, group_col: str, value_col: str,
    groups: list[str] | None = None, filters: dict | None = None,
) -> dict:
    """Compare a numeric value across groups with descriptive stats.

    Args:
        dataset: Name of the dataset to load.
        group_col: Column whose values define the groups.
        value_col: Numeric column to summarise per group.
        groups: Optional whitelist of group values to keep.
        filters: Optional equality filters applied before grouping.

    Returns:
        Dict with per-group count/mean/median/std, saved artifacts, and Plotly JSON.
    """
    columns = list(dict.fromkeys([group_col, value_col, "subreddit", "author", "score", "created_utc"]))
    frame = _apply_filters(_load(dataset, columns=columns), filters)
    # Coerce the value column so non-numeric entries become NaN rather than erroring.
    frame[value_col] = pd.to_numeric(frame[value_col], errors="coerce")
    if groups:
        frame = frame[frame[group_col].isin(groups)]
    summary = (
        frame.groupby(group_col)[value_col]
        .agg(count="count", mean="mean", median="median", std="std")
        .reset_index()
        .sort_values("median", ascending=False)
        .round(3)
    )
    stem = f"compare_{group_col}_{value_col}_{dataset}"
    csv_path = _save_csv(summary, stem)
    fig = px.bar(
        summary, x=group_col, y="median", error_y="std",
        title=f"{value_col} by {group_col} (median Β± std)",
        labels={"median": f"Median {value_col}"},
    )
    try:
        png_path = _save_fig(fig, stem)
    except Exception:
        # Static export is optional; the Plotly JSON below is always returned.
        png_path = None
    return {
        "analysis": "compare_groups", "dataset": dataset,
        "group_col": group_col, "value_col": value_col,
        "filters": filters or {}, "groups_compared": summary[group_col].tolist(),
        "table": summary.to_dict(orient="records"),
        "saved_csv": csv_path, "saved_png": png_path, "plotly_json": fig.to_json(),
    }
# ══════════════════════════════════════════════════════════════════════════════
# 7. Core agent loop
# ══════════════════════════════════════════════════════════════════════════════
# Anthropic model identifier used for every agent call.
MODEL = "claude-opus-4-6"
# Tool schemas advertised to the model (Anthropic tool-use format). Each entry's
# "name" must have a matching callable in TOOL_FN_MAP below.
TOOLS = [
    # ── Metadata & inspection ──────────────────────────────────────────────
    {
        "name": "list_datasets",
        "description": (
            "List cached dataset metadata: paths, row counts, columns, subreddits, and date ranges. "
            "Use this to inspect the available data without loading full tables."
        ),
        "input_schema": {
            "type": "object",
            "properties": {"refresh": {"type": "boolean", "default": False,
                           "description": "Recompute metadata from source parquets instead of using the cache."}},
            "required": [],
        },
    },
    {
        "name": "sample_rows",
        "description": "Return a small deterministic preview of rows from a dataset, optionally filtered and column-limited.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]},
                "n": {"type": "integer", "default": 5},
                "filters": {"type": "object", "description": "Optional equality filters, e.g. {\"subreddit\": \"GOONED\"}"},
                "columns": {"type": "array", "items": {"type": "string"}, "description": "Optional subset of columns to preview."},
            },
            "required": ["dataset"],
        },
    },
    # ── Aggregation & statistics ───────────────────────────────────────────
    {
        "name": "count_by_group",
        "description": "Count rows in a dataset grouped by one column, with optional shared filters.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]},
                "group_col": {"type": "string"},
                "top_n": {"type": "integer", "default": 30},
                "filters": {"type": "object"},
            },
            "required": ["dataset", "group_col"],
        },
    },
    {
        "name": "trend_over_time",
        "description": "Count rows over time, optionally split by one grouping column, with optional shared filters.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]},
                # Pandas resample frequency aliases: day / week / month / quarter / year.
                "freq": {"type": "string", "enum": ["D", "W", "M", "Q", "Y"], "default": "M"},
                "group_col": {"type": "string"},
                "top_groups": {"type": "integer", "default": 8},
                "filters": {"type": "object"},
            },
            "required": ["dataset"],
        },
    },
    {
        "name": "summary_stats",
        "description": "Descriptive statistics for a numeric column, optionally grouped and filtered.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]},
                "value_col": {"type": "string"},
                "group_col": {"type": "string"},
                "top_n": {"type": "integer", "default": 30},
                "filters": {"type": "object"},
            },
            "required": ["dataset", "value_col"],
        },
    },
    {
        "name": "top_posts",
        "description": "Return the highest-scoring posts, optionally filtered by subreddit or shared filters.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "titles"], "default": "posts"},
                "n": {"type": "integer", "default": 20},
                "subreddit": {"type": "string"},
                "text_col": {"type": "string", "default": "title"},
                "filters": {"type": "object"},
            },
            "required": [],
        },
    },
    # ── Text analysis ──────────────────────────────────────────────────────
    {
        "name": "text_search",
        "description": "Search for a phrase in a text column and return top matching rows.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]},
                "query": {"type": "string"},
                "text_col": {"type": "string", "default": "body"},
                "n": {"type": "integer", "default": 20},
                "subreddit": {"type": "string"},
                "filters": {"type": "object"},
            },
            "required": ["dataset", "query"],
        },
    },
    {
        "name": "word_freq",
        "description": "Count word frequencies in a text column with optional shared filters.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"], "default": "corpus_clean"},
                "text_col": {"type": "string", "default": "text_cleaned"},
                "top_n": {"type": "integer", "default": 50},
                "subreddit": {"type": "string"},
                "min_length": {"type": "integer", "default": 4},
                "filters": {"type": "object"},
            },
            "required": [],
        },
    },
    {
        "name": "compare_groups",
        "description": "Compare one numeric column across groups with optional shared filters.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]},
                "group_col": {"type": "string"},
                "value_col": {"type": "string"},
                "groups": {"type": "array", "items": {"type": "string"}},
                "filters": {"type": "object"},
            },
            "required": ["dataset", "group_col", "value_col"],
        },
    },
    # ── Pattern mining ─────────────────────────────────────────────────────
    {
        "name": "extract_frequency_patterns",
        "description": "Mine text for frequency and duration language across the full dataset.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean"], "default": "comments"},
                "text_col": {"type": "string", "default": "body"},
                "subreddit": {"type": "string"},
                "n_examples": {"type": "integer", "default": 5},
                "sample_size": {"type": "integer", "default": 5000000},
            },
            "required": [],
        },
    },
    {
        "name": "extract_dominance_patterns",
        "description": "Count dominant vs subordinate language in text, not images.",
        "input_schema": {
            "type": "object",
            "properties": {
                "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean"], "default": "comments"},
                "text_col": {"type": "string", "default": "body"},
                "subreddit": {"type": "string"},
                "sample_size": {"type": "integer", "default": 5000000},
            },
            "required": [],
        },
    },
    # ── Image coding & reliability workflow ────────────────────────────────
    {
        "name": "analyze_image_sample",
        "description": "Run vision coding on a sample of image posts using Qwen2-VL via Together AI (no content filters). Always provide a coding_scheme for research use.",
        "input_schema": {
            "type": "object",
            "properties": {
                "question": {"type": "string"},
                "subreddit": {"type": "string"},
                "n_sample": {"type": "integer", "default": 100, "description": "Number of images to code. No hard cap β€” set to 500+ for large analyses."},
                "coding_scheme": {"type": "object", "description": "Dict of {label: definition}. Always provide this for research questions."},
            },
            "required": ["question"],
        },
    },
    {
        "name": "export_reliability_sample",
        "description": "Export a stratified random sample of coded images for human validation. Run after analyze_image_sample.",
        "input_schema": {
            "type": "object",
            "properties": {
                "source_csv": {"type": "string", "description": "Path to image_analysis CSV. Defaults to most recent."},
                "n": {"type": "integer", "default": 200},
                "random_state": {"type": "integer", "default": 42},
            },
            "required": [],
        },
    },
    {
        "name": "compute_reliability",
        "description": "Compute Cohen's kappa between model and human codes after the human_label column has been filled in.",
        "input_schema": {
            "type": "object",
            "properties": {
                "human_csv_path": {"type": "string", "description": "Path to completed reliability_sample.csv. Defaults to outputs/reliability_sample.csv."},
            },
            "required": [],
        },
    },
]
def _kwargs_call(func):
    """Adapt a tool implementation to the dispatch convention fn(args_dict) -> func(**args_dict)."""
    return lambda args: func(**args)

# Maps each advertised tool name to a callable taking the raw tool-input dict
# produced by the model. Keys must stay in sync with TOOLS above.
TOOL_FN_MAP = {
    "list_datasets": _kwargs_call(list_datasets),
    "sample_rows": _kwargs_call(sample_rows),
    "count_by_group": _kwargs_call(count_by_group),
    "trend_over_time": _kwargs_call(trend_over_time),
    "summary_stats": _kwargs_call(summary_stats),
    "top_posts": _kwargs_call(top_posts),
    "text_search": _kwargs_call(text_search),
    "word_freq": _kwargs_call(word_freq),
    "compare_groups": _kwargs_call(compare_groups),
    "extract_frequency_patterns": _kwargs_call(extract_frequency_patterns),
    "extract_dominance_patterns": _kwargs_call(extract_dominance_patterns),
    "analyze_image_sample": _kwargs_call(analyze_image_sample),
    "export_reliability_sample": _kwargs_call(export_reliability_sample),
    "compute_reliability": _kwargs_call(compute_reliability),
}
def _safe_str(obj: object) -> object:
"""Recursively encode any non-ASCII strings as JSON-safe escaped text."""
if isinstance(obj, str):
return obj.encode("ascii", errors="backslashreplace").decode("ascii")
if isinstance(obj, dict):
return {k: _safe_str(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_safe_str(item) for item in obj]
return obj
def _compact_result(result: object) -> dict:
if not isinstance(result, dict):
return {"value": result}
compact = {}
for key in ("analysis", "dataset", "group_col", "value_col", "query", "filters",
"n_matches", "n_returned", "n_total", "groups_compared", "saved_csv", "saved_png", "error"):
if key in result and result.get(key) is not None:
compact[key] = result[key]
table = result.get("table")
if isinstance(table, list):
compact["table_preview"] = table[:3]
compact["table_rows"] = len(table)
return compact
def _conversation_state_summary(turns: list[dict] | None) -> str:
if not turns:
return "No prior analytical state."
summary = []
for idx, turn in enumerate(turns[-3:], start=1):
summary.append({
"turn": idx,
"question": _safe_str(turn.get("question", "")),
"answer": _safe_str(turn.get("answer", "")),
"tool_calls": [
{"tool": tc.get("tool"), "args": _safe_str(tc.get("args", {})),
"result": _safe_str(_compact_result(tc.get("result")))}
for tc in turn.get("tool_calls", [])
],
"artifacts": turn.get("artifacts", []),
})
return json.dumps(summary, default=str, indent=2)
def _tool_names(tools: list[dict]) -> list[str]:
return [t["name"] for t in tools]
def _tool_subset(allowed_tools: list[str]) -> list[dict]:
    """Return the schemas from the module-level TOOLS whose names are allowed.

    Order follows TOOLS, not allowed_tools; unknown names are silently ignored.
    """
    wanted = set(allowed_tools)
    return [schema for schema in TOOLS if schema["name"] in wanted]
def _system_prompt(route_mode: str, route_guidance: str, conversation_state: str) -> str:
    """Build the full system prompt: dataset metadata, routing info, prior state, and rules.

    Args:
        route_mode: Name of the deterministic route chosen for this question.
        route_guidance: Route-specific instructions to inject into the prompt.
        conversation_state: JSON summary of prior turns
            (output of _conversation_state_summary).

    Returns:
        The complete system prompt string sent to the model.
    """
    metadata = get_dataset_metadata()
    dataset_lines = []
    for name, info in metadata.items():
        # Skip datasets flagged unavailable in the cached metadata.
        if not info.get("available"):
            continue
        date_range = info.get("date_range") or {}
        dataset_lines.append(
            f"- {name}: {info.get('rows')} rows; columns={list(info.get('columns', {}).keys())}; "
            f"date_range={date_range or 'n/a'}"
        )
    dataset_summary = "\n".join(dataset_lines)
    # NOTE: the f-string body below is left-aligned deliberately; it is the
    # literal prompt text, so do not re-indent it.
    return f"""You are a question-driven data analysis agent working over local Reddit datasets.
Available dataset metadata:
{dataset_summary}
Current route mode: {route_mode}
Route guidance: {route_guidance}
Prior analytical state:
{conversation_state}
Rules:
1. Use the route guidance and only the provided tools.
2. Inspect metadata or row previews before making assumptions when the schema is unclear.
3. Run actual tools for numbers; do not guess.
4. Prefer one minimal reproducible tool path over exploratory tool spam.
5. Distinguish direct findings from caveats.
6. If prior turns already produced a relevant result, reuse that context instead of recomputing unless the user asks for a change.
7. Answer with this structure: direct answer, what was analysed, method, caveats.
8. ALWAYS prefer tools that produce charts (trend_over_time, count_by_group, compare_groups, summary_stats, word_freq) over plain text summaries when the question is quantitative. Every numeric answer should have a chart.
9. For questions about images or visual content, use analyze_image_sample. It reads from raw CSV files with image URLs β€” no separate setup needed. ALWAYS generate an explicit coding_scheme dict (with label names as keys and definitions as values) before calling this tool β€” never leave coding_scheme null for a research question.
10. After a large image coding run, offer to run export_reliability_sample to generate a human validation set, then compute_reliability once the user has filled in the human_label column.
11. The dataset covers 30 subreddits including GOONED, GOONEDISBACK, GoonCaves, girlgooners, and more. Use subreddit filters to drill into specific communities."""
def run_agent(
    question: str,
    history: list[dict] | None = None,
    turns: list[dict] | None = None,
    analysis_context: list[dict] | None = None,
    conversation_state: list[dict] | None = None,
) -> dict:
    """Run the agent for a user question with deterministic routing and structured prior state.

    Thin wrapper around _run_agent_inner that converts a UnicodeEncodeError
    anywhere in the pipeline into a RuntimeError carrying the full traceback,
    so callers see a diagnosable message instead of a raw encoding failure.
    """
    try:
        return _run_agent_inner(question, history, turns, analysis_context, conversation_state)
    except UnicodeEncodeError as exc:
        trace = _traceback.format_exc()
        message = (
            f"Unicode encoding error (non-ASCII character in data pipeline).\n\n"
            f"Detail: {exc}\n\nTraceback:\n{trace}"
        )
        raise RuntimeError(message) from exc
def _run_agent_inner(
    question: str,
    history: list[dict] | None = None,
    turns: list[dict] | None = None,
    analysis_context: list[dict] | None = None,
    conversation_state: list[dict] | None = None,
) -> dict:
    """Drive the Claude tool-use loop for one question and return the structured result.

    Args:
        question: The user's question.
        history: Prior chat messages in Anthropic message format.
        turns / analysis_context / conversation_state: Alternative names for the
            prior-turn summaries; the first non-empty one is used.

    Returns:
        Dict with the answer text, tool-call log, Plotly chart JSON(s), route
        info, the allowed tool names, and token/cost usage.
    """
    client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
    # Accept any of the three aliases for prior-turn state; first non-empty wins.
    prior_turns = turns or analysis_context or conversation_state or []
    route = route_question(question)
    available_tools = _tool_subset(route.allowed_tools)
    safe_history = [_safe_str(msg) for msg in (history or [])]
    messages = safe_history
    messages.append({"role": "user", "content": _safe_str(question)})
    tool_calls_log = []
    plotly_jsons = []
    total_input_tokens = 0
    total_output_tokens = 0
    system = _system_prompt(
        route_mode=route.mode,
        route_guidance=route.guidance,
        conversation_state=_conversation_state_summary(prior_turns),
    )
    # Tool-use loop: call the model, execute requested tools, feed results back,
    # and repeat until the model answers without requesting tools.
    while True:
        safe_messages = _safe_str(messages)
        safe_system = _safe_str(system)
        try:
            response = client.messages.create(
                model=MODEL,
                max_tokens=4096,
                system=safe_system,
                tools=available_tools,
                messages=safe_messages,
            )
        except UnicodeEncodeError:
            # Last-resort fallback: force everything to pure ASCII (dropping any
            # characters _safe_str missed) and retry the API call once.
            stripped_messages = json.loads(json.dumps(safe_messages, default=str, ensure_ascii=True))
            stripped_system = safe_system.encode("ascii", errors="ignore").decode("ascii")
            response = client.messages.create(
                model=MODEL,
                max_tokens=4096,
                system=stripped_system,
                tools=available_tools,
                messages=stripped_messages,
            )
        text_parts = [block.text for block in response.content if block.type == "text"]
        tool_use_blocks = [block for block in response.content if block.type == "tool_use"]
        total_input_tokens += getattr(response.usage, "input_tokens", 0) or 0
        total_output_tokens += getattr(response.usage, "output_tokens", 0) or 0
        # Terminate when the model ends its turn or stops requesting tools.
        if response.stop_reason == "end_turn" or not tool_use_blocks:
            # Claude Opus 4.6 pricing: $15/M input, $75/M output
            cost_usd = (total_input_tokens / 1_000_000 * 15.0) + (total_output_tokens / 1_000_000 * 75.0)
            return {
                "answer": "\n".join(text_parts).strip(),
                "tool_calls": tool_calls_log,
                # The most recently produced chart fills the single-chart field.
                "plotly_json": plotly_jsons[-1] if plotly_jsons else None,
                "plotly_jsons": plotly_jsons,
                "route": route.mode,
                "allowed_tools": _tool_names(available_tools),
                "usage": {
                    "input_tokens": total_input_tokens,
                    "output_tokens": total_output_tokens,
                    "cost_usd": round(cost_usd, 4),
                },
            }
        tool_results = []
        for block in tool_use_blocks:
            fn = TOOL_FN_MAP.get(block.name)
            if fn is None:
                result = {"error": f"Unknown tool: {block.name}"}
            else:
                try:
                    result = fn(block.input)
                    # Collect any chart JSON the tool produced.
                    if isinstance(result, dict) and result.get("plotly_json"):
                        plotly_jsons.append(result["plotly_json"])
                except Exception as exc:
                    # Surface tool failures to the model instead of crashing the loop.
                    result = {"error": str(exc)}
            safe_result = _safe_str(result)
            # The log keeps the raw (unsanitized) result; only the payload sent
            # back to the API is ASCII-sanitized.
            tool_calls_log.append({"tool": block.name, "args": block.input, "result": result})
            tool_results.append({
                "type": "tool_result",
                "tool_use_id": block.id,
                "content": json.dumps(safe_result, default=str, ensure_ascii=True),
            })
        # Echo the assistant's tool-use turn, then supply the results as a user turn,
        # per the Anthropic tool-use message protocol.
        assistant_content = [_safe_str(block.model_dump()) for block in response.content]
        messages.append({"role": "assistant", "content": assistant_content})
        messages.append({"role": "user", "content": tool_results})