| """ |
Goon analysis agent: all modules in one file.
| |
| Sections: |
| 1. Response formatter |
| 2. Data inspection & sampling |
| 3. Question router |
| 4. Image analysis |
| 5. Text pattern extraction |
| 6. Analysis execution (count, trend, stats, search, word freq, compare) |
| 7. Core agent loop |
| """ |
|
|
| from __future__ import annotations |
|
|
| |
| import base64 |
| import glob |
| import io |
| import json |
| import os |
| import re |
| import traceback as _traceback |
| import urllib.request |
| from dataclasses import dataclass |
| from datetime import datetime, timezone |
| from pathlib import Path |
|
|
| |
| import anthropic |
| import openai |
| import pandas as pd |
| import plotly.express as px |
| import plotly.io as pio |
| import pyarrow.dataset as ds |
| import pyarrow.parquet as pq |
| from pyarrow.compute import field |
| from sklearn.metrics import cohen_kappa_score |
|
|
|
|
# ---------------------------------------------------------------------------
# 1. Response formatter
# ---------------------------------------------------------------------------
|
|
| def format_result(result: dict, answer_text: str = "") -> str: |
| """Combine Claude's prose answer with the structured analysis result as markdown.""" |
| lines = [] |
|
|
| if answer_text: |
| lines.append("## Answer\n") |
| lines.append(answer_text.strip()) |
| lines.append("") |
|
|
| dataset = result.get("dataset", "") |
| if dataset: |
| lines.append("## What was analysed\n") |
| lines.append(f"- Dataset: `{dataset}`") |
| if result.get("subreddit_filter"): |
| lines.append(f"- Subreddit filter: `{result['subreddit_filter']}`") |
| if result.get("group_col"): |
| lines.append(f"- Grouped by: `{result['group_col']}`") |
| if result.get("value_col"): |
| lines.append(f"- Value column: `{result['value_col']}`") |
| lines.append("") |
|
|
| table = result.get("table") |
| if table: |
| lines.append("## Results\n") |
| lines.append(_dict_list_to_md_table(table)) |
| lines.append("") |
|
|
| saved = [result[k] for k in ("saved_csv", "saved_png") if result.get(k)] |
| if saved: |
| lines.append("## Saved outputs\n") |
| for s in saved: |
| lines.append(f"- `{s}`") |
| lines.append("") |
|
|
| return "\n".join(lines) |
|
|
|
|
| def _dict_list_to_md_table(records: list[dict]) -> str: |
| if not records: |
| return "_No results._" |
| headers = list(records[0].keys()) |
| rows = [[str(r.get(h, "")) for h in headers] for r in records] |
| widths = [max(len(h), max((len(r[i]) for r in rows), default=0)) for i, h in enumerate(headers)] |
| sep = "| " + " | ".join("-" * w for w in widths) + " |" |
| header_row = "| " + " | ".join(h.ljust(widths[i]) for i, h in enumerate(headers)) + " |" |
| data_rows = [ |
| "| " + " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)) + " |" |
| for row in rows[:50] |
| ] |
| return "\n".join([header_row, sep] + data_rows) |
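

# Usage sketch (illustrative, never called): the result dict below mirrors the
# shape produced by the analysis tools in section 6; the field values are
# made up for demonstration.
def _example_format_result() -> str:
    result = {
        "dataset": "posts",
        "subreddit_filter": "GOONED",
        "table": [{"subreddit": "GOONED", "count": 1234}],
        "saved_csv": "outputs/count_by_subreddit_posts.csv",
    }
    return format_result(result, answer_text="GOONED is the most active subreddit.")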
|
|
|
|
# ---------------------------------------------------------------------------
# 2. Data inspection & sampling
# ---------------------------------------------------------------------------
|
|
| DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data")) |
| OUTPUTS_DIR = Path(__file__).parent / "outputs" |
| OUTPUTS_DIR.mkdir(exist_ok=True) |
| METADATA_CACHE = OUTPUTS_DIR / "dataset_metadata.json" |
|
|
|
|
| def _best(name: str) -> Path: |
| """Prefer the full rebuilt parquet over the original, but validate it first.""" |
| full = DATA_DIR / f"{name}_full.parquet" |
| orig = DATA_DIR / f"{name}.parquet" |
| if full.exists(): |
| try: |
| pq.read_schema(full) |
| return full |
| except Exception: |
| pass |
| return orig |
|
|
|
|
| DATASETS = { |
| "posts": _best("posts"), |
| "comments": _best("comments"), |
| "corpus_clean": DATA_DIR / "corpus_clean.parquet", |
| "titles": _best("titles"), |
| } |
|
|
|
|
| def _dataset_path(name: str) -> Path: |
| path = DATASETS.get(name) |
| if path is None or not path.exists(): |
| raise FileNotFoundError(f"Dataset '{name}' not found at {path}") |
| return path |
|
|
|
|
| def _load(name: str, columns: list[str] | None = None) -> pd.DataFrame: |
| return pd.read_parquet(_dataset_path(name), columns=columns) |
|
|
|
|
| def _scanner(name: str, columns: list[str] | None = None, filters: dict | None = None) -> ds.Scanner: |
| path = _dataset_path(name) |
| dataset = ds.dataset(path, format="parquet") |
| expression = None |
| for col, value in (filters or {}).items(): |
| if col not in dataset.schema.names or value in (None, ""): |
| continue |
| clause = field(col) == value |
| expression = clause if expression is None else expression & clause |
| return dataset.scanner(columns=columns, filter=expression) |
|
|
|
|
| def _read_distinct_values(name: str, column: str, limit: int = 200) -> list[str] | None: |
| if column not in _schema_names(name): |
| return None |
| table = _scanner(name, columns=[column]).to_table() |
| values = table.column(column).drop_null().unique().to_pylist() |
| return sorted(str(v) for v in values)[:limit] |
|
|
|
|
| def _read_date_range(name: str) -> dict | None: |
| if "created_utc" not in _schema_names(name): |
| return None |
| table = _scanner(name, columns=["created_utc"]).to_table() |
| if table.num_rows == 0: |
| return None |
| series = table.column("created_utc").to_pandas().dropna() |
| if series.empty: |
| return None |
| return { |
| "earliest": datetime.fromtimestamp(series.min(), tz=timezone.utc).strftime("%Y-%m-%d"), |
| "latest": datetime.fromtimestamp(series.max(), tz=timezone.utc).strftime("%Y-%m-%d"), |
| } |
|
|
|
|
| def _schema_names(name: str) -> list[str]: |
| return pq.read_schema(_dataset_path(name)).names |
|
|
|
|
| def compute_dataset_metadata() -> dict: |
| result = {} |
| for name, path in DATASETS.items(): |
| if not path.exists(): |
| result[name] = {"available": False} |
| continue |
| parquet = pq.ParquetFile(path) |
| schema = parquet.schema_arrow |
| result[name] = { |
| "available": True, |
| "path": str(path), |
| "rows": parquet.metadata.num_rows, |
| "columns": {f.name: str(f.type) for f in schema}, |
| "subreddits": _read_distinct_values(name, "subreddit"), |
| "date_range": _read_date_range(name), |
| "metadata_cached_at": datetime.now(timezone.utc).isoformat(), |
| } |
| METADATA_CACHE.write_text(json.dumps(result, indent=2)) |
| return result |
|
|
|
|
| def get_dataset_metadata(refresh: bool = False) -> dict: |
| if METADATA_CACHE.exists() and not refresh: |
| return json.loads(METADATA_CACHE.read_text()) |
| return compute_dataset_metadata() |
|
|
|
|
| def list_datasets(refresh: bool = False) -> dict: |
| """Return cached dataset metadata instead of loading full tables.""" |
| return get_dataset_metadata(refresh=refresh) |
|
|
|
|
| def sample_rows( |
| dataset: str, |
| n: int = 5, |
| filters: dict | None = None, |
| columns: list[str] | None = None, |
| ) -> dict: |
| """Return a small deterministic preview of rows from a dataset, optionally filtered.""" |
| selected_columns = columns or _schema_names(dataset) |
| table = _scanner(dataset, columns=selected_columns, filters=filters).head(n) |
| df = table.to_pandas() if table.num_rows else pd.DataFrame(columns=selected_columns) |
| return { |
| "dataset": dataset, |
| "filters": filters or {}, |
| "n_returned": len(df), |
| "rows": df.fillna("").to_dict(orient="records"), |
| } |
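

# Example (sketch): a typical inspection flow before running any analysis.
# Wrapped in a function so importing this module stays side-effect free; the
# GOONED filter is illustrative.
def _example_inspect_data() -> dict:
    meta = list_datasets()  # cached metadata; no full table loads
    preview = sample_rows(
        "posts", n=3,
        filters={"subreddit": "GOONED"},
        columns=["subreddit", "title", "score"],
    )
    return {"datasets": list(meta), "preview_rows": preview["n_returned"]}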
|
|
|
|
# ---------------------------------------------------------------------------
# 3. Question router
# ---------------------------------------------------------------------------
|
|
| @dataclass(frozen=True) |
| class RoutePlan: |
| mode: str |
| allowed_tools: list[str] |
| guidance: str |
|
|
|
|
| ALL_TOOL_NAMES = [ |
| "list_datasets", "sample_rows", "count_by_group", "trend_over_time", |
| "summary_stats", "top_posts", "text_search", "word_freq", "compare_groups", |
| "extract_frequency_patterns", "extract_dominance_patterns", "analyze_image_sample", |
| "export_reliability_sample", "compute_reliability", |
| ] |
|
|
|
|
| def route_question(question: str) -> RoutePlan: |
| q = question.lower() |
|
|
| if any(t in q for t in ["image", "images", "photo", "photos", "visual", "depicted"]): |
| return RoutePlan( |
| mode="image", |
| allowed_tools=["list_datasets", "sample_rows", "analyze_image_sample", "export_reliability_sample", "compute_reliability"], |
| guidance="This is a visual-content question. Prefer image analysis tools and avoid text-only proxies. Always provide a coding_scheme.", |
| ) |
| if any(t in q for t in ["reliability", "kappa", "human coding", "inter-rater", "validate"]): |
| return RoutePlan( |
| mode="reliability", |
| allowed_tools=["export_reliability_sample", "compute_reliability"], |
| guidance="This is a reliability/validation question. Use export_reliability_sample then compute_reliability.", |
| ) |
| if any(t in q for t in ["how often", "how long", "times per", "every day", "session length", "streak"]): |
| return RoutePlan( |
| mode="pattern_frequency", |
| allowed_tools=["list_datasets", "sample_rows", "extract_frequency_patterns", "text_search"], |
| guidance="This is a behavioral frequency/duration question. Prefer regex pattern extraction over generic word counts.", |
| ) |
| if any(t in q for t in ["dominant", "subordinate", "mistress", "goddess", "femdom", "submissive"]): |
| return RoutePlan( |
| mode="pattern_dominance", |
| allowed_tools=["list_datasets", "sample_rows", "extract_dominance_patterns", "text_search", "analyze_image_sample"], |
| guidance="This is a dominance/subordination framing question. Use the text pattern tool unless the user explicitly asks about images.", |
| ) |
| if any(t in q for t in ["over time", "trend", "changed", "change over time", "monthly", "yearly"]): |
| return RoutePlan( |
| mode="trend", |
| allowed_tools=["list_datasets", "sample_rows", "trend_over_time", "count_by_group"], |
| guidance="This is a time-series question. Prefer trend_over_time and only use grouping/count tools to contextualize it.", |
| ) |
| if any(t in q for t in ["compare", "difference", "versus", "vs", "higher", "lower"]): |
| return RoutePlan( |
| mode="compare", |
| allowed_tools=["list_datasets", "sample_rows", "compare_groups", "summary_stats", "count_by_group"], |
| guidance="This is a comparison question. Prefer compare_groups or summary_stats with explicit filters.", |
| ) |
| if any(t in q for t in ["top", "highest", "best scoring", "most upvoted"]): |
| return RoutePlan( |
| mode="ranking", |
| allowed_tools=["list_datasets", "sample_rows", "top_posts", "summary_stats"], |
| guidance="This is a ranking question. Prefer top_posts and use summary_stats only if it supports the answer.", |
| ) |
| if any(t in q for t in ["search", "find", "mention", "contains", "where people say"]): |
| return RoutePlan( |
| mode="search", |
| allowed_tools=["list_datasets", "sample_rows", "text_search", "top_posts"], |
| guidance="This is a retrieval question. Prefer text_search with the right dataset and text column.", |
| ) |
| if any(t in q for t in ["common words", "most common words", "word frequency", "tokens"]): |
| return RoutePlan( |
| mode="lexical", |
| allowed_tools=["list_datasets", "sample_rows", "word_freq", "text_search"], |
| guidance="This is a lexical summary question. Prefer word_freq and inspect text samples only if needed.", |
| ) |
| if any(t in q for t in ["how many", "count", "number of", "what proportion"]): |
| return RoutePlan( |
| mode="describe", |
| allowed_tools=["list_datasets", "sample_rows", "count_by_group", "summary_stats", "trend_over_time"], |
| guidance="This is a descriptive count question. Prefer count_by_group or summary_stats and keep the plan minimal.", |
| ) |
| return RoutePlan( |
| mode="unknown", |
| allowed_tools=ALL_TOOL_NAMES, |
| guidance="Question type is ambiguous. Inspect metadata first, then choose the minimum reliable tool path.", |
| ) |
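

# Example (sketch): routing is keyword-based and deterministic; branches are
# checked top to bottom, so "over time" hits the trend route before any later
# branch can claim the question.
def _example_route() -> RoutePlan:
    plan = route_question("How did posting activity change over time?")
    assert plan.mode == "trend" and "trend_over_time" in plan.allowed_tools
    return plan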
|
|
|
|
# ---------------------------------------------------------------------------
# 4. Image analysis
# ---------------------------------------------------------------------------
|
|
| VISION_MODEL = "Qwen/Qwen2-VL-72B-Instruct" |
| TOGETHER_BASE_URL = "https://api.together.xyz/v1" |
| DIRECT_IMAGE_DOMAINS = {"i.redd.it", "i.imgur.com", "i.redgifs.com"} |
|
|
|
|
| def _load_image_urls(subreddit: str | None = None, n: int = 50) -> pd.DataFrame: |
| pattern = str(DATA_DIR / "*_submissions_*.csv") |
| files = sorted(glob.glob(pattern)) |
| if subreddit: |
| files = [f for f in files if Path(f).name.lower().startswith(subreddit.lower())] |
|
|
| needed_cols = ["subreddit", "title", "url", "domain", "score", "is_self"] |
| frames = [] |
| for f in files: |
| try: |
| df = pd.read_csv(f, usecols=lambda c: c in needed_cols, low_memory=False) |
| if "is_self" in df.columns: |
| df = df[df["is_self"] == False] |
| if "url" in df.columns and "domain" in df.columns: |
| df = df[df["domain"].isin(DIRECT_IMAGE_DOMAINS)].dropna(subset=["url"]) |
| frames.append(df[["subreddit", "title", "url", "domain", "score"]]) |
| except Exception: |
| continue |
|
|
| if not frames: |
| return pd.DataFrame() |
| combined = pd.concat(frames, ignore_index=True) |
| if len(combined) > n * 10: |
| combined = combined.sample(min(n * 10, len(combined)), random_state=42) |
| return combined.head(n * 10) |
|
|
|
|
| def _fetch_image_b64(url: str, timeout: int = 8) -> tuple[str, str] | None: |
| try: |
| url.encode("ascii") |
| except UnicodeEncodeError: |
| return None |
| try: |
| req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0 (research bot)"}) |
| with urllib.request.urlopen(req, timeout=timeout) as resp: |
| content_type = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip() |
| if not content_type.startswith("image/"): |
| return None |
| data = resp.read() |
| if len(data) < 1000: |
| return None |
| return base64.standard_b64encode(data).decode("utf-8"), content_type |
| except Exception: |
| return None |
|
|
|
|
| def analyze_image_sample( |
| question: str, |
| subreddit: str | None = None, |
| n_sample: int = 100, |
| coding_scheme: dict | None = None, |
| ) -> dict: |
| """ |
| Sample image posts, fetch them, and ask Qwen2-VL a structured content-analysis question. |
    Uses Together AI (no content filters). n_sample is uncapped; set it as needed.
| """ |
| client = openai.OpenAI( |
| api_key=os.environ["TOGETHER_API_KEY"], |
| base_url=TOGETHER_BASE_URL, |
| ) |
| candidates = _load_image_urls(subreddit=subreddit, n=n_sample * 5) |
|
|
| if candidates.empty: |
| return { |
| "analysis": "analyze_image_sample", |
| "error": "No direct image URLs found in raw CSVs for the given filters.", |
| "subreddit_filter": subreddit, |
| } |
|
|
| if coding_scheme: |
| scheme_text = "\n".join(f"- {k}: {v}" for k, v in coding_scheme.items()) |
| prompt = ( |
| f"{question}\n\nCoding scheme:\n{scheme_text}\n\n" |
| "Reply with ONLY the label and a one-sentence justification, " |
| "formatted as: LABEL | justification" |
| ) |
| else: |
| prompt = ( |
| f"{question}\n\n" |
| "Reply with a short structured answer. " |
| "If you cannot determine this from the image, reply: UNCLEAR | reason" |
| ) |
|
|
| results = [] |
| attempted = 0 |
|
|
| for _, row in candidates.iterrows(): |
| if len(results) >= n_sample: |
| break |
| attempted += 1 |
| img = _fetch_image_b64(row["url"]) |
| if img is None: |
| continue |
| b64data, media_type = img |
|
|
| try: |
| response = client.chat.completions.create( |
| model=VISION_MODEL, |
| max_tokens=200, |
| messages=[{ |
| "role": "user", |
| "content": [ |
| {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{b64data}"}}, |
| {"type": "text", "text": prompt}, |
| ], |
| }], |
| ) |
| answer = response.choices[0].message.content.strip() |
| except Exception as e: |
| answer = f"ERROR | {e}" |
|
|
| parts = answer.split("|", 1) |
| label = parts[0].strip().upper() if parts else "UNCLEAR" |
| justification = parts[1].strip() if len(parts) > 1 else "" |
|
|
| def _ascii_safe(s: str) -> str: |
| return s.encode("ascii", errors="replace").decode("ascii") |
|
|
| results.append({ |
| "subreddit": _ascii_safe(str(row.get("subreddit", ""))), |
| "title": _ascii_safe(str(row.get("title", ""))), |
| "url": row["url"], |
| "label": label, |
| "justification": _ascii_safe(justification), |
| "score": row.get("score", None), |
| }) |
|
|
| label_counts: dict[str, int] = {} |
| for r in results: |
| label_counts[r["label"]] = label_counts.get(r["label"], 0) + 1 |
|
|
| total_coded = len(results) |
| saved_csv = None |
| if results: |
| out_df = pd.DataFrame(results) |
| stem = f"image_analysis_{subreddit or 'all'}" |
| saved_csv = str(OUTPUTS_DIR / f"{stem}.csv") |
| out_df.to_csv(saved_csv, index=False) |
|
|
| return { |
| "analysis": "analyze_image_sample", |
| "question": question, |
| "subreddit_filter": subreddit, |
| "n_attempted": attempted, |
| "n_successfully_coded": total_coded, |
| "label_counts": label_counts, |
| "label_pct": {k: round(v / total_coded * 100, 1) for k, v in label_counts.items()} if total_coded else {}, |
| "per_image_results": results, |
| "saved_csv": saved_csv, |
        "caveats": [
            "Sample limited to direct-image domains (i.redd.it, i.imgur.com, i.redgifs.com); galleries and videos are excluded.",
            f"Vision model: {VISION_MODEL} via Together AI.",
            "Coded by a single model; validate with a human reliability sample before reporting.",
        ],
| } |
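

# Illustrative coding scheme for analyze_image_sample. The labels and
# definitions are an example only; design the scheme around the actual
# research question before a real run, e.g.:
#   analyze_image_sample("How many people are depicted?", subreddit="GOONED",
#                        n_sample=100, coding_scheme=EXAMPLE_CODING_SCHEME)
EXAMPLE_CODING_SCHEME = {
    "SOLO": "Exactly one person is depicted.",
    "MULTIPLE": "Two or more people are depicted.",
    "NO_PERSON": "No person is visible (text, objects, screenshots).",
    "UNCLEAR": "Cannot be determined from the image.",
}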
|
|
|
|
| def export_reliability_sample( |
| source_csv: str | None = None, |
| n: int = 200, |
| random_state: int = 42, |
| ) -> dict: |
| """ |
| Draw a stratified random sample of n images from a completed image_analysis CSV |
| for human coding. Saves a CSV with an empty human_label column. |
| """ |
| if source_csv is None: |
        # Default to the most recently written image_analysis CSV in outputs/.
| candidates = sorted(OUTPUTS_DIR.glob("image_analysis_*.csv")) |
| if not candidates: |
| return {"error": "No image_analysis CSV found in outputs/. Run analyze_image_sample first."} |
| source_csv = str(candidates[-1]) |
|
|
| df = pd.read_csv(source_csv) |
    df = df[df["label"].notna() & ~df["label"].str.startswith("ERROR", na=False)]
|
|
    # Stratified sample: proportional to the model's label distribution.
| sampled = ( |
| df.groupby("label", group_keys=False) |
| .apply(lambda g: g.sample(min(len(g), max(1, int(n * len(g) / len(df)))), random_state=random_state)) |
| ) |
    # Top up with unsampled rows if stratification fell short of n.
| if len(sampled) < n and len(df) >= n: |
| remaining = df[~df.index.isin(sampled.index)] |
| top_up = remaining.sample(n - len(sampled), random_state=random_state) |
| sampled = pd.concat([sampled, top_up]) |
|
|
| sampled = sampled.sample(frac=1, random_state=random_state).reset_index(drop=True) |
| sampled.insert(0, "image_id", range(1, len(sampled) + 1)) |
| sampled = sampled.rename(columns={"label": "model_label", "justification": "model_justification"}) |
| sampled["human_label"] = "" |
|
|
| out_cols = ["image_id", "url", "title", "subreddit", "model_label", "model_justification", "human_label"] |
| out_cols = [c for c in out_cols if c in sampled.columns] |
| out_path = str(OUTPUTS_DIR / "reliability_sample.csv") |
| sampled[out_cols].to_csv(out_path, index=False) |
|
|
| return { |
| "analysis": "export_reliability_sample", |
| "source_csv": source_csv, |
| "n_exported": len(sampled), |
| "label_distribution": sampled["model_label"].value_counts().to_dict(), |
| "saved_csv": out_path, |
| "next_step": "Fill in the human_label column, then run compute_reliability.", |
| } |
|
|
|
|
| def compute_reliability(human_csv_path: str | None = None) -> dict: |
| """ |
| Compute Cohen's kappa between model_label and human_label columns |
| in a completed reliability sample CSV. |
| """ |
| if human_csv_path is None: |
| human_csv_path = str(OUTPUTS_DIR / "reliability_sample.csv") |
|
|
| df = pd.read_csv(human_csv_path) |
| df = df[df["human_label"].notna() & (df["human_label"].astype(str).str.strip() != "")] |
|
|
| if len(df) < 2: |
| return {"error": "Not enough human-coded rows. Fill in human_label column first."} |
|
|
| model = df["model_label"].astype(str).str.strip().str.upper() |
| human = df["human_label"].astype(str).str.strip().str.upper() |
|
|
| kappa = cohen_kappa_score(human, model) |
| pct_agreement = round((human == model).mean() * 100, 1) |
|
|
| per_label = {} |
| for label in sorted(human.unique()): |
| h = (human == label) |
| m = (model == label) |
| tp = int((h & m).sum()) |
| fp = int((~h & m).sum()) |
| fn = int((h & ~m).sum()) |
| per_label[label] = {"human_n": int(h.sum()), "model_n": int(m.sum()), |
| "exact_matches": tp, "false_positives": fp, "false_negatives": fn} |
|
|
| report = { |
| "analysis": "compute_reliability", |
| "n_coded": len(df), |
| "cohens_kappa": round(kappa, 3), |
| "percent_agreement": pct_agreement, |
        "interpretation": (
            "excellent (κ ≥ 0.80)" if kappa >= 0.80 else
            "substantial (κ 0.60–0.79)" if kappa >= 0.60 else
            "moderate (κ 0.40–0.59)" if kappa >= 0.40 else
            "fair (κ 0.20–0.39)" if kappa >= 0.20 else
            "poor (κ < 0.20)"
        ),
| "per_label": per_label, |
| } |
|
|
| out_path = str(OUTPUTS_DIR / "reliability_report.json") |
| Path(out_path).write_text(json.dumps(report, indent=2)) |
| report["saved_json"] = out_path |
| return report |
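

# End-to-end validation sketch: code a sample of images, export a stratified
# subset for a human coder, then compute agreement once the human_label column
# in outputs/reliability_sample.csv has been filled in. Labels are illustrative.
def _example_reliability_workflow() -> dict:
    analyze_image_sample(
        "Is a person depicted?", n_sample=300,
        coding_scheme={"YES": "A person is visible.", "NO": "No person is visible.",
                       "UNCLEAR": "Cannot be determined from the image."},
    )
    export_reliability_sample(n=200)
    # ...human coder fills in human_label, then:
    return compute_reliability()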
|
|
|
|
# ---------------------------------------------------------------------------
# 5. Text pattern extraction
# ---------------------------------------------------------------------------
|
|
| FREQUENCY_PATTERNS = { |
| "times_per_day": [ |
| r"\b(\d+)\s*(?:times?|x)\s*(?:a|per)\s*day\b", |
| r"\b(\d+)\s*(?:times?|x)\s*daily\b", |
| ], |
| "times_per_week": [ |
| r"\b(\d+)\s*(?:times?|x)\s*(?:a|per)\s*week\b", |
| r"\b(\d+)\s*(?:times?|x)\s*weekly\b", |
| ], |
| "hours_per_session": [ |
| r"\b(\d+(?:\.\d+)?)\s*(?:hours?|hrs?)\b", |
| r"\b(\d+(?:\.\d+)?)\s*(?:hours?|hrs?)\s*(?:session|goon|long|straight|solid|non.?stop)\b", |
| ], |
| "all_day": [ |
| r"\ball\s*day\b", r"\ball\s*night\b", r"\ball\s*weekend\b", r"\bfor\s*hours\b", |
| ], |
| "daily_habit": [ |
| r"\bevery\s*day\b", r"\bevery\s*night\b", r"\bdaily\b", r"\bmost\s*days?\b", |
| ], |
| "streak_days": [ |
| r"\b(\d+)\s*(?:days?\s*(?:in\s*a\s*row|straight|streak|running))\b", |
        r"\b(\d+)\s*(?:-|–)?\s*day\s*(?:streak|binge)\b",
| ], |
| } |
|
|
| DOMINANCE_PATTERNS = { |
| "dominant_language": [ |
        r"\bdominat(?:e|es|ed|ing|ion|rix)\b", r"\bfem(?:dom|domme)\b",
| r"\bmistress\b", r"\bgoddess\b", r"\bqueen\b", r"\bowner\b", |
| r"\balpha\b", r"\bin\s*control\b", r"\bboss\b", |
| ], |
| "subordinate_language": [ |
| r"\bsubmiss(?:ive|ion)\b", r"\bsub\b", r"\bobedient\b", r"\bslave\b", |
| r"\bpet\b", r"\bslut\b", r"\bwhore\b", r"\bused\b", |
| r"\bcontrolled\b", r"\bworshiped?\b", r"\bworship(?:ped|ing)\b", |
| ], |
| "neutral_object": [ |
| r"\bperfect\b", r"\bbeautiful\b", r"\bhot\b", r"\bsexy\b", r"\bstunning\b", |
| ], |
| } |
|
|
|
|
| def _compile(patterns: list[str]) -> re.Pattern: |
| return re.compile("|".join(patterns), re.IGNORECASE) |
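

# Example (sketch): a compiled category regex matches any of its alternatives
# case-insensitively, so one vectorized pass per category can flag documents.
def _example_pattern_match() -> bool:
    regex = _compile(FREQUENCY_PATTERNS["times_per_day"])
    return bool(regex.search("I do it 5 times a day"))  # True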
|
|
|
|
| def extract_frequency_patterns( |
| dataset: str = "comments", |
| text_col: str = "body", |
| subreddit: str | None = None, |
| n_examples: int = 5, |
| sample_size: int = 5_000_000, |
| ) -> dict: |
| """Mine text for frequency and duration language across the full dataset.""" |
| cols = [text_col] + (["subreddit"] if subreddit else []) |
| df = _scanner( |
| dataset, columns=cols, |
| filters={"subreddit": subreddit} if subreddit else None, |
| ).head(sample_size).to_pandas() |
|
|
| text = df[text_col].fillna("") |
| total_docs = len(text) |
| results = {} |
|
|
| for category, pats in FREQUENCY_PATTERNS.items(): |
| regex = _compile(pats) |
        matches_mask = text.str.contains(regex, na=False)  # compiled pattern keeps IGNORECASE
| hit_texts = text[matches_mask] |
| values = [] |
| for pat in pats: |
| r = re.compile(pat, re.IGNORECASE) |
| for t in hit_texts: |
| for m in r.finditer(t): |
| if m.groups(): |
| try: |
| values.append(float(m.group(1))) |
| except (IndexError, ValueError): |
| pass |
| raw_examples = hit_texts.sample(min(n_examples, len(hit_texts)), random_state=42).tolist() if len(hit_texts) > 0 else [] |
| results[category] = { |
| "count": int(matches_mask.sum()), |
| "pct_of_docs": round(matches_mask.mean() * 100, 3), |
| "numeric_values": sorted(values)[:50] if values else [], |
| "mean_value": round(sum(values) / len(values), 2) if values else None, |
| "examples": [t.encode("ascii", errors="replace").decode("ascii") for t in raw_examples], |
| } |
|
|
| return { |
| "analysis": "extract_frequency_patterns", |
| "dataset": dataset, |
| "text_col": text_col, |
| "subreddit_filter": subreddit, |
| "total_docs_sampled": total_docs, |
| "patterns": results, |
| } |
|
|
|
|
| def extract_dominance_patterns( |
| dataset: str = "comments", |
| text_col: str = "body", |
| subreddit: str | None = None, |
| sample_size: int = 5_000_000, |
| ) -> dict: |
| """Count dominant, subordinate, and neutral language in text.""" |
| cols = [text_col] + (["subreddit"] if subreddit else []) |
| df = _scanner( |
| dataset, columns=cols, |
| filters={"subreddit": subreddit} if subreddit else None, |
| ).head(sample_size).to_pandas() |
|
|
| text = df[text_col].fillna("") |
| total_docs = len(text) |
| results = {} |
|
|
| for category, pats in DOMINANCE_PATTERNS.items(): |
| regex = _compile(pats) |
| mask = text.str.contains(regex, na=False) |
| hits = text[mask] |
| raw_examples = hits.sample(min(5, len(hits)), random_state=42).tolist() if len(hits) > 0 else [] |
| results[category] = { |
| "count": int(mask.sum()), |
| "pct_of_docs": round(mask.mean() * 100, 3), |
| "examples": [t.encode("ascii", errors="replace").decode("ascii") for t in raw_examples], |
| } |
|
|
| dom = results.get("dominant_language", {}).get("count", 0) |
| sub = results.get("subordinate_language", {}).get("count", 0) |
| total = dom + sub |
| ratio = { |
| "dominant_pct": round(dom / total * 100, 1) if total else None, |
| "subordinate_pct": round(sub / total * 100, 1) if total else None, |
| "interpretation": ( |
| "More subordinate language" if sub > dom else |
| "More dominant language" if dom > sub else |
| "Roughly balanced" |
| ) if total else "No data", |
| } |
|
|
| return { |
| "analysis": "extract_dominance_patterns", |
| "dataset": dataset, |
| "text_col": text_col, |
| "subreddit_filter": subreddit, |
| "total_docs_sampled": total_docs, |
| "categories": results, |
| "dominance_ratio": ratio, |
| "caveat": ( |
| "This analysis counts language patterns in text, not visual image content. " |
| "It reflects how women are described in text, not how they appear in images. " |
| "For image-based analysis use analyze_image_sample." |
| ), |
| } |
|
|
|
|
# ---------------------------------------------------------------------------
# 6. Analysis execution (count, trend, stats, search, word freq, compare)
# ---------------------------------------------------------------------------
|
|
| def _ts_to_date(series: pd.Series) -> pd.Series: |
| return pd.to_datetime(series, unit="s", utc=True) |
|
|
|
|
| def _normalize_filters( |
| filters: dict | None = None, |
| subreddit: str | None = None, |
| date_from: str | None = None, |
| date_to: str | None = None, |
| min_score: float | None = None, |
| ) -> dict: |
| merged = dict(filters or {}) |
| if subreddit: |
| merged["subreddit"] = subreddit |
| if date_from: |
| merged["date_from"] = date_from |
| if date_to: |
| merged["date_to"] = date_to |
| if min_score is not None: |
| merged["min_score"] = min_score |
| return merged |
|
|
|
|
| def _apply_filters(df: pd.DataFrame, filters: dict | None = None) -> pd.DataFrame: |
| if not filters: |
| return df |
| filtered = df |
| if filters.get("subreddit") and "subreddit" in filtered.columns: |
| filtered = filtered[filtered["subreddit"] == filters["subreddit"]] |
| if filters.get("author") and "author" in filtered.columns: |
| filtered = filtered[filtered["author"] == filters["author"]] |
    if filters.get("min_score") is not None and "score" in filtered.columns:
        filtered = filtered.assign(score=pd.to_numeric(filtered["score"], errors="coerce"))
        filtered = filtered[filtered["score"] >= filters["min_score"]]
    if ("date_from" in filters or "date_to" in filters) and "created_utc" in filtered.columns:
        created = _ts_to_date(filtered["created_utc"])
        if filters.get("date_from"):
            mask = created >= pd.Timestamp(filters["date_from"], tz="UTC")
            filtered, created = filtered[mask], created[mask]
        if filters.get("date_to"):
            filtered = filtered[created <= pd.Timestamp(filters["date_to"], tz="UTC")]
    return filtered
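

# Example (sketch): the shared filter dict understood by _apply_filters.
# Every key is optional and unknown keys are ignored; the author name is a
# placeholder.
EXAMPLE_FILTERS = {
    "subreddit": "GOONED",      # equality match
    "author": "some_user",      # equality match (placeholder)
    "min_score": 10,            # numeric threshold on score
    "date_from": "2020-01-01",  # inclusive UTC bound on created_utc
    "date_to": "2023-12-31",    # inclusive UTC bound on created_utc
}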
|
|
|
|
| def _save_csv(df: pd.DataFrame, stem: str) -> str: |
| path = OUTPUTS_DIR / f"{stem}.csv" |
| df.to_csv(path, index=False) |
| return str(path) |
|
|
|
|
| def _save_fig(fig, stem: str) -> str: |
| path = OUTPUTS_DIR / f"{stem}.png" |
| pio.write_image(fig, str(path), scale=2) |
| return str(path) |
|
|
|
|
| def count_by_group(dataset: str, group_col: str, top_n: int = 30, filters: dict | None = None) -> dict: |
| """Count rows grouped by a column. Returns sorted table + bar chart.""" |
| filter_cols = [c for c in ["subreddit", "author", "score", "created_utc"] if c != group_col] |
| df = _load(dataset, columns=list(dict.fromkeys([group_col] + filter_cols))) |
| df = _apply_filters(df, filters) |
| counts = ( |
| df.groupby(group_col, dropna=False) |
| .size().reset_index(name="count") |
| .sort_values("count", ascending=False).head(top_n) |
| ) |
| stem = f"count_by_{group_col}_{dataset}" |
| saved_csv = _save_csv(counts, stem) |
| fig = px.bar( |
| counts.sort_values("count"), x="count", y=group_col, orientation="h", |
| title=f"Count by {group_col}", labels={"count": "Count", group_col: group_col}, |
| ) |
| fig.update_layout(yaxis={"categoryorder": "total ascending"}) |
| try: |
| saved_png = _save_fig(fig, stem) |
| except Exception: |
| saved_png = None |
| return { |
| "analysis": "count_by_group", "dataset": dataset, "group_col": group_col, |
| "filters": filters or {}, "total_rows": len(df), |
| "table": counts.to_dict(orient="records"), |
| "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": fig.to_json(), |
| } |
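

# Example (sketch): count comments per subreddit since 2022 using the shared
# filter dict; returns a sorted table plus CSV and chart artifacts.
def _example_count_by_group() -> dict:
    return count_by_group(
        "comments", group_col="subreddit", top_n=10,
        filters={"date_from": "2022-01-01"},
    )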
|
|
|
|
| def trend_over_time( |
| dataset: str, freq: str = "M", group_col: str | None = None, |
| top_groups: int = 8, filters: dict | None = None, |
| ) -> dict: |
| """Count posts/comments over time, optionally broken out by a grouping column.""" |
| cols = ["created_utc"] + ([group_col] if group_col else []) + ["subreddit", "author", "score"] |
| cols = list(dict.fromkeys(cols)) |
| df = _load(dataset, columns=cols) |
| df = _apply_filters(df, filters) |
| df["period"] = _ts_to_date(df["created_utc"]).dt.to_period(freq).astype(str) |
|
|
| if group_col: |
| top = df[group_col].value_counts().head(top_groups).index.tolist() |
| df = df[df[group_col].isin(top)] |
| counts = ( |
| df.groupby(["period", group_col]).size() |
| .reset_index(name="count").sort_values("period") |
| ) |
| fig = px.line(counts, x="period", y="count", color=group_col, |
| title=f"Activity over time by {group_col}") |
| else: |
| counts = ( |
| df.groupby("period").size() |
| .reset_index(name="count").sort_values("period") |
| ) |
| fig = px.line(counts, x="period", y="count", title="Activity over time") |
|
|
| stem = f"trend_{dataset}_{group_col or 'all'}_{freq}" |
| saved_csv = _save_csv(counts, stem) |
| try: |
| saved_png = _save_fig(fig, stem) |
| except Exception: |
| saved_png = None |
| return { |
| "analysis": "trend_over_time", "dataset": dataset, "freq": freq, |
| "group_col": group_col, "filters": filters or {}, |
| "table": counts.to_dict(orient="records"), |
| "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": fig.to_json(), |
| } |
|
|
|
|
| def summary_stats( |
| dataset: str, value_col: str, group_col: str | None = None, |
| top_n: int = 30, filters: dict | None = None, |
| ) -> dict: |
| """Descriptive statistics for a numeric column, optionally by group.""" |
| cols = [value_col] + ([group_col] if group_col else []) + ["subreddit", "author", "score", "created_utc"] |
| cols = list(dict.fromkeys(cols)) |
| df = _load(dataset, columns=cols) |
| df = _apply_filters(df, filters) |
| df[value_col] = pd.to_numeric(df[value_col], errors="coerce") |
|
|
| if group_col: |
| stats = ( |
| df.groupby(group_col)[value_col] |
| .agg(["count", "mean", "median", "std", "min", "max"]) |
| .reset_index().sort_values("mean", ascending=False).head(top_n).round(2) |
| ) |
| else: |
| raw = df[value_col].describe().round(2) |
| stats = raw.reset_index() |
| stats.columns = ["stat", "value"] |
|
|
| stem = f"stats_{value_col}_{group_col or 'all'}_{dataset}" |
| saved_csv = _save_csv(stats, stem) |
| try: |
| if group_col: |
| fig = px.bar(stats, x=group_col, y="mean", error_y="std", |
| title=f"{value_col} by {group_col}", |
| labels={"mean": f"Mean {value_col}"}) |
| else: |
| fig = px.histogram(df[value_col].dropna(), nbins=50, |
| title=f"Distribution of {value_col}", |
| labels={"value": value_col}) |
| saved_png = _save_fig(fig, stem) |
| plotly_json = fig.to_json() |
| except Exception: |
| saved_png = None |
| plotly_json = None |
| return { |
| "analysis": "summary_stats", "dataset": dataset, "value_col": value_col, |
| "group_col": group_col, "filters": filters or {}, |
| "n_total": len(df), "n_missing": int(df[value_col].isna().sum()), |
| "table": stats.to_dict(orient="records"), |
| "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": plotly_json, |
| } |
|
|
|
|
| def top_posts( |
| dataset: str = "posts", n: int = 20, |
| subreddit: str | None = None, text_col: str = "title", |
| filters: dict | None = None, |
| ) -> dict: |
| """Return the highest-scoring posts, optionally filtered to a subreddit.""" |
| filters = _normalize_filters(filters=filters, subreddit=subreddit) |
| cols = [c for c in ["subreddit", "author", text_col, "score", "created_utc"] if c] |
| df = _load(dataset, columns=cols) |
| df = _apply_filters(df, filters) |
| top = df.nlargest(n, "score")[cols].copy() |
| top["date"] = _ts_to_date(top["created_utc"]).dt.strftime("%Y-%m-%d") |
| top = top.drop(columns=["created_utc"]) |
| stem = f"top_posts_{subreddit or 'all'}_{dataset}" |
| saved = _save_csv(top, stem) |
| return { |
| "analysis": "top_posts", "dataset": dataset, |
| "subreddit_filter": subreddit, "filters": filters, "n": n, |
| "table": top.fillna("").to_dict(orient="records"), "saved_csv": saved, |
| } |
|
|
|
|
| def text_search( |
| dataset: str, query: str, text_col: str = "body", |
| n: int = 20, case_sensitive: bool = False, |
| subreddit: str | None = None, filters: dict | None = None, |
| ) -> dict: |
| """Search for a string pattern in a text column.""" |
| filters = _normalize_filters(filters=filters, subreddit=subreddit) |
| cols = [c for c in ["subreddit", "author", text_col, "score", "created_utc"] if c] |
| df = _load(dataset, columns=cols) |
| df = _apply_filters(df, filters) |
| mask = df[text_col].fillna("").str.contains(query, case=case_sensitive, regex=False) |
| hits = df[mask].nlargest(n, "score").copy() |
| hits["date"] = _ts_to_date(hits["created_utc"]).dt.strftime("%Y-%m-%d") |
| hits = hits.drop(columns=["created_utc"]) |
| stem = f"search_{query[:30].replace(' ', '_')}_{dataset}" |
| saved = _save_csv(hits, stem) |
| return { |
| "analysis": "text_search", "dataset": dataset, "query": query, |
| "text_col": text_col, "filters": filters, |
| "n_matches": int(mask.sum()), "n_returned": len(hits), |
| "table": hits.fillna("").to_dict(orient="records"), "saved_csv": saved, |
| } |
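

# Example (sketch): top-scoring comments that literally contain a phrase
# (regex=False) inside one subreddit; the query and subreddit are illustrative.
def _example_text_search() -> dict:
    return text_search(
        "comments", query="every day", text_col="body",
        n=10, subreddit="GOONED",
    )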
|
|
|
|
| def word_freq( |
| dataset: str = "corpus_clean", text_col: str = "text_cleaned", |
| top_n: int = 50, subreddit: str | None = None, |
| min_length: int = 4, filters: dict | None = None, |
| ) -> dict: |
| """Count word frequencies in a text column.""" |
| filters = _normalize_filters(filters=filters, subreddit=subreddit) |
| cols = list(dict.fromkeys([text_col] + (["subreddit"] if subreddit else []) + ["author", "score", "created_utc"])) |
| df = _load(dataset, columns=cols) |
| df = _apply_filters(df, filters) |
|
|
| stop = { |
| "the","and","for","that","with","this","you","are","was","not", |
| "have","from","they","will","what","been","when","your","more", |
| "just","about","like","there","were","would","into","than","then", |
| "some","also","very","only","over","back","can","out","all","but", |
| "one","had","has","its","which","their","time","our","who","may", |
| "after","other","these","those","such","each","him","her","his", |
| "she","how","did","being","now","way","any","too","much","even", |
        "get","got","could","should","make","made","said","still",
| "here","because","really","know","think","going","reddit","post", |
| "comment","deleted","removed", |
| } |
|
|
| words = ( |
| df[text_col].fillna("").str.lower() |
| .str.replace(r"[^a-z\s]", " ", regex=True).str.split().explode() |
| ) |
| words = words[words.str.len() >= min_length] |
| words = words[~words.isin(stop)] |
| counts = words.value_counts().head(top_n).reset_index() |
| counts.columns = ["word", "count"] |
| stem = f"wordfreq_{text_col}_{subreddit or 'all'}_{dataset}" |
| saved_csv = _save_csv(counts, stem) |
| fig = px.bar( |
| counts.head(30).sort_values("count"), x="count", y="word", orientation="h", |
| title="Top words by frequency", labels={"count": "Count", "word": "Word"}, |
| ) |
| fig.update_layout(yaxis={"categoryorder": "total ascending"}) |
| try: |
| saved_png = _save_fig(fig, stem) |
| except Exception: |
| saved_png = None |
| return { |
| "analysis": "word_freq", "dataset": dataset, "text_col": text_col, |
| "subreddit_filter": subreddit, "filters": filters, "total_docs": len(df), |
| "table": counts.to_dict(orient="records"), |
| "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": fig.to_json(), |
| } |
|
|
|
|
| def compare_groups( |
| dataset: str, group_col: str, value_col: str, |
| groups: list[str] | None = None, filters: dict | None = None, |
| ) -> dict: |
| """Compare a numeric value across groups with descriptive stats.""" |
| cols = list(dict.fromkeys([group_col, value_col, "subreddit", "author", "score", "created_utc"])) |
| df = _load(dataset, columns=cols) |
| df = _apply_filters(df, filters) |
| df[value_col] = pd.to_numeric(df[value_col], errors="coerce") |
| if groups: |
| df = df[df[group_col].isin(groups)] |
| stats = ( |
| df.groupby(group_col)[value_col] |
| .agg(count="count", mean="mean", median="median", std="std") |
| .reset_index().sort_values("median", ascending=False).round(3) |
| ) |
| stem = f"compare_{group_col}_{value_col}_{dataset}" |
| saved_csv = _save_csv(stats, stem) |
| fig = px.bar(stats, x=group_col, y="median", error_y="std", |
| title=f"{value_col} by {group_col} (median Β± std)", |
| labels={"median": f"Median {value_col}"}) |
| try: |
| saved_png = _save_fig(fig, stem) |
| except Exception: |
| saved_png = None |
| return { |
| "analysis": "compare_groups", "dataset": dataset, |
| "group_col": group_col, "value_col": value_col, |
| "filters": filters or {}, "groups_compared": stats[group_col].tolist(), |
| "table": stats.to_dict(orient="records"), |
| "saved_csv": saved_csv, "saved_png": saved_png, "plotly_json": fig.to_json(), |
| } |
|
|
|
|
# ---------------------------------------------------------------------------
# 7. Core agent loop
# ---------------------------------------------------------------------------
|
|
| MODEL = "claude-opus-4-6" |
|
|
| TOOLS = [ |
| { |
| "name": "list_datasets", |
| "description": ( |
| "List cached dataset metadata: paths, row counts, columns, subreddits, and date ranges. " |
| "Use this to inspect the available data without loading full tables." |
| ), |
| "input_schema": { |
| "type": "object", |
| "properties": {"refresh": {"type": "boolean", "default": False, |
| "description": "Recompute metadata from source parquets instead of using the cache."}}, |
| "required": [], |
| }, |
| }, |
| { |
| "name": "sample_rows", |
| "description": "Return a small deterministic preview of rows from a dataset, optionally filtered and column-limited.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]}, |
| "n": {"type": "integer", "default": 5}, |
| "filters": {"type": "object", "description": "Optional equality filters, e.g. {\"subreddit\": \"GOONED\"}"}, |
| "columns": {"type": "array", "items": {"type": "string"}, "description": "Optional subset of columns to preview."}, |
| }, |
| "required": ["dataset"], |
| }, |
| }, |
| { |
| "name": "count_by_group", |
| "description": "Count rows in a dataset grouped by one column, with optional shared filters.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]}, |
| "group_col": {"type": "string"}, |
| "top_n": {"type": "integer", "default": 30}, |
| "filters": {"type": "object"}, |
| }, |
| "required": ["dataset", "group_col"], |
| }, |
| }, |
| { |
| "name": "trend_over_time", |
| "description": "Count rows over time, optionally split by one grouping column, with optional shared filters.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]}, |
| "freq": {"type": "string", "enum": ["D", "W", "M", "Q", "Y"], "default": "M"}, |
| "group_col": {"type": "string"}, |
| "top_groups": {"type": "integer", "default": 8}, |
| "filters": {"type": "object"}, |
| }, |
| "required": ["dataset"], |
| }, |
| }, |
| { |
| "name": "summary_stats", |
| "description": "Descriptive statistics for a numeric column, optionally grouped and filtered.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]}, |
| "value_col": {"type": "string"}, |
| "group_col": {"type": "string"}, |
| "top_n": {"type": "integer", "default": 30}, |
| "filters": {"type": "object"}, |
| }, |
| "required": ["dataset", "value_col"], |
| }, |
| }, |
| { |
| "name": "top_posts", |
| "description": "Return the highest-scoring posts, optionally filtered by subreddit or shared filters.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "titles"], "default": "posts"}, |
| "n": {"type": "integer", "default": 20}, |
| "subreddit": {"type": "string"}, |
| "text_col": {"type": "string", "default": "title"}, |
| "filters": {"type": "object"}, |
| }, |
| "required": [], |
| }, |
| }, |
| { |
| "name": "text_search", |
| "description": "Search for a phrase in a text column and return top matching rows.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]}, |
| "query": {"type": "string"}, |
| "text_col": {"type": "string", "default": "body"}, |
| "n": {"type": "integer", "default": 20}, |
| "subreddit": {"type": "string"}, |
| "filters": {"type": "object"}, |
| }, |
| "required": ["dataset", "query"], |
| }, |
| }, |
| { |
| "name": "word_freq", |
| "description": "Count word frequencies in a text column with optional shared filters.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"], "default": "corpus_clean"}, |
| "text_col": {"type": "string", "default": "text_cleaned"}, |
| "top_n": {"type": "integer", "default": 50}, |
| "subreddit": {"type": "string"}, |
| "min_length": {"type": "integer", "default": 4}, |
| "filters": {"type": "object"}, |
| }, |
| "required": [], |
| }, |
| }, |
| { |
| "name": "compare_groups", |
| "description": "Compare one numeric column across groups with optional shared filters.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean", "titles"]}, |
| "group_col": {"type": "string"}, |
| "value_col": {"type": "string"}, |
| "groups": {"type": "array", "items": {"type": "string"}}, |
| "filters": {"type": "object"}, |
| }, |
| "required": ["dataset", "group_col", "value_col"], |
| }, |
| }, |
| { |
| "name": "extract_frequency_patterns", |
| "description": "Mine text for frequency and duration language across the full dataset.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean"], "default": "comments"}, |
| "text_col": {"type": "string", "default": "body"}, |
| "subreddit": {"type": "string"}, |
| "n_examples": {"type": "integer", "default": 5}, |
| "sample_size": {"type": "integer", "default": 5000000}, |
| }, |
| "required": [], |
| }, |
| }, |
| { |
| "name": "extract_dominance_patterns", |
| "description": "Count dominant vs subordinate language in text, not images.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "dataset": {"type": "string", "enum": ["posts", "comments", "corpus_clean"], "default": "comments"}, |
| "text_col": {"type": "string", "default": "body"}, |
| "subreddit": {"type": "string"}, |
| "sample_size": {"type": "integer", "default": 5000000}, |
| }, |
| "required": [], |
| }, |
| }, |
| { |
| "name": "analyze_image_sample", |
| "description": "Run vision coding on a sample of image posts using Qwen2-VL via Together AI (no content filters). Always provide a coding_scheme for research use.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "question": {"type": "string"}, |
| "subreddit": {"type": "string"}, |
                "n_sample": {"type": "integer", "default": 100, "description": "Number of images to code. No hard cap; set to 500+ for large analyses."},
| "coding_scheme": {"type": "object", "description": "Dict of {label: definition}. Always provide this for research questions."}, |
| }, |
| "required": ["question"], |
| }, |
| }, |
| { |
| "name": "export_reliability_sample", |
| "description": "Export a stratified random sample of coded images for human validation. Run after analyze_image_sample.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "source_csv": {"type": "string", "description": "Path to image_analysis CSV. Defaults to most recent."}, |
| "n": {"type": "integer", "default": 200}, |
| "random_state": {"type": "integer", "default": 42}, |
| }, |
| "required": [], |
| }, |
| }, |
| { |
| "name": "compute_reliability", |
| "description": "Compute Cohen's kappa between model and human codes after the human_label column has been filled in.", |
| "input_schema": { |
| "type": "object", |
| "properties": { |
| "human_csv_path": {"type": "string", "description": "Path to completed reliability_sample.csv. Defaults to outputs/reliability_sample.csv."}, |
| }, |
| "required": [], |
| }, |
| }, |
| ] |
|
|
| TOOL_FN_MAP = { |
| "list_datasets": lambda args: list_datasets(**args), |
| "sample_rows": lambda args: sample_rows(**args), |
| "count_by_group": lambda args: count_by_group(**args), |
| "trend_over_time": lambda args: trend_over_time(**args), |
| "summary_stats": lambda args: summary_stats(**args), |
| "top_posts": lambda args: top_posts(**args), |
| "text_search": lambda args: text_search(**args), |
| "word_freq": lambda args: word_freq(**args), |
| "compare_groups": lambda args: compare_groups(**args), |
| "extract_frequency_patterns": lambda args: extract_frequency_patterns(**args), |
| "extract_dominance_patterns": lambda args: extract_dominance_patterns(**args), |
| "analyze_image_sample": lambda args: analyze_image_sample(**args), |
| "export_reliability_sample": lambda args: export_reliability_sample(**args), |
| "compute_reliability": lambda args: compute_reliability(**args), |
| } |
|
|
|
|
| def _safe_str(obj: object) -> object: |
| """Recursively encode any non-ASCII strings as JSON-safe escaped text.""" |
| if isinstance(obj, str): |
| return obj.encode("ascii", errors="backslashreplace").decode("ascii") |
| if isinstance(obj, dict): |
| return {k: _safe_str(v) for k, v in obj.items()} |
| if isinstance(obj, list): |
| return [_safe_str(item) for item in obj] |
| return obj |
|
|
|
|
| def _compact_result(result: object) -> dict: |
| if not isinstance(result, dict): |
| return {"value": result} |
| compact = {} |
| for key in ("analysis", "dataset", "group_col", "value_col", "query", "filters", |
| "n_matches", "n_returned", "n_total", "groups_compared", "saved_csv", "saved_png", "error"): |
| if key in result and result.get(key) is not None: |
| compact[key] = result[key] |
| table = result.get("table") |
| if isinstance(table, list): |
| compact["table_preview"] = table[:3] |
| compact["table_rows"] = len(table) |
| return compact |
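

# Example (sketch): compaction keeps only small identifying fields before a
# result is folded into the prior-state summary; bulky fields such as
# plotly_json and the full table are dropped in favor of a 3-row preview.
def _example_compact() -> dict:
    full = {"analysis": "count_by_group", "dataset": "posts",
            "table": [{"a": 1}] * 100, "plotly_json": "{...}"}
    return _compact_result(full)  # keeps analysis/dataset, adds table_preview + table_rows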
|
|
|
|
| def _conversation_state_summary(turns: list[dict] | None) -> str: |
| if not turns: |
| return "No prior analytical state." |
| summary = [] |
| for idx, turn in enumerate(turns[-3:], start=1): |
| summary.append({ |
| "turn": idx, |
| "question": _safe_str(turn.get("question", "")), |
| "answer": _safe_str(turn.get("answer", "")), |
| "tool_calls": [ |
| {"tool": tc.get("tool"), "args": _safe_str(tc.get("args", {})), |
| "result": _safe_str(_compact_result(tc.get("result")))} |
| for tc in turn.get("tool_calls", []) |
| ], |
| "artifacts": turn.get("artifacts", []), |
| }) |
| return json.dumps(summary, default=str, indent=2) |
|
|
|
|
| def _tool_names(tools: list[dict]) -> list[str]: |
| return [t["name"] for t in tools] |
|
|
|
|
| def _tool_subset(allowed_tools: list[str]) -> list[dict]: |
| allowed = set(allowed_tools) |
| return [t for t in TOOLS if t["name"] in allowed] |
|
|
|
|
| def _system_prompt(route_mode: str, route_guidance: str, conversation_state: str) -> str: |
| metadata = get_dataset_metadata() |
| dataset_lines = [] |
| for name, info in metadata.items(): |
| if not info.get("available"): |
| continue |
| date_range = info.get("date_range") or {} |
| dataset_lines.append( |
| f"- {name}: {info.get('rows')} rows; columns={list(info.get('columns', {}).keys())}; " |
| f"date_range={date_range or 'n/a'}" |
| ) |
| dataset_summary = "\n".join(dataset_lines) |
| return f"""You are a question-driven data analysis agent working over local Reddit datasets. |
| |
| Available dataset metadata: |
| {dataset_summary} |
| |
| Current route mode: {route_mode} |
| Route guidance: {route_guidance} |
| |
| Prior analytical state: |
| {conversation_state} |
| |
| Rules: |
| 1. Use the route guidance and only the provided tools. |
| 2. Inspect metadata or row previews before making assumptions when the schema is unclear. |
| 3. Run actual tools for numbers; do not guess. |
| 4. Prefer one minimal reproducible tool path over exploratory tool spam. |
| 5. Distinguish direct findings from caveats. |
| 6. If prior turns already produced a relevant result, reuse that context instead of recomputing unless the user asks for a change. |
| 7. Answer with this structure: direct answer, what was analysed, method, caveats. |
| 8. ALWAYS prefer tools that produce charts (trend_over_time, count_by_group, compare_groups, summary_stats, word_freq) over plain text summaries when the question is quantitative. Every numeric answer should have a chart. |
9. For questions about images or visual content, use analyze_image_sample. It reads from raw CSV files with image URLs; no separate setup is needed. ALWAYS generate an explicit coding_scheme dict (with label names as keys and definitions as values) before calling this tool; never leave coding_scheme null for a research question.
| 10. After a large image coding run, offer to run export_reliability_sample to generate a human validation set, then compute_reliability once the user has filled in the human_label column. |
| 11. The dataset covers 30 subreddits including GOONED, GOONEDISBACK, GoonCaves, girlgooners, and more. Use subreddit filters to drill into specific communities.""" |
|
|
|
|
| def run_agent( |
| question: str, |
| history: list[dict] | None = None, |
| turns: list[dict] | None = None, |
| analysis_context: list[dict] | None = None, |
| conversation_state: list[dict] | None = None, |
| ) -> dict: |
| """Run the agent for a user question with deterministic routing and structured prior state.""" |
| try: |
| return _run_agent_inner(question, history, turns, analysis_context, conversation_state) |
| except UnicodeEncodeError as exc: |
| tb = _traceback.format_exc() |
| raise RuntimeError( |
| f"Unicode encoding error (non-ASCII character in data pipeline).\n\n" |
| f"Detail: {exc}\n\nTraceback:\n{tb}" |
| ) from exc |
|
|
|
|
| def _run_agent_inner( |
| question: str, |
| history: list[dict] | None = None, |
| turns: list[dict] | None = None, |
| analysis_context: list[dict] | None = None, |
| conversation_state: list[dict] | None = None, |
| ) -> dict: |
| client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) |
| prior_turns = turns or analysis_context or conversation_state or [] |
| route = route_question(question) |
| available_tools = _tool_subset(route.allowed_tools) |
|
|
| safe_history = [_safe_str(msg) for msg in (history or [])] |
| messages = safe_history |
| messages.append({"role": "user", "content": _safe_str(question)}) |
|
|
| tool_calls_log = [] |
| plotly_jsons = [] |
| total_input_tokens = 0 |
| total_output_tokens = 0 |
| system = _system_prompt( |
| route_mode=route.mode, |
| route_guidance=route.guidance, |
| conversation_state=_conversation_state_summary(prior_turns), |
| ) |
|
|
| while True: |
| safe_messages = _safe_str(messages) |
| safe_system = _safe_str(system) |
| try: |
| response = client.messages.create( |
| model=MODEL, |
| max_tokens=4096, |
| system=safe_system, |
| tools=available_tools, |
| messages=safe_messages, |
| ) |
| except UnicodeEncodeError: |
| stripped_messages = json.loads(json.dumps(safe_messages, default=str, ensure_ascii=True)) |
| stripped_system = safe_system.encode("ascii", errors="ignore").decode("ascii") |
| response = client.messages.create( |
| model=MODEL, |
| max_tokens=4096, |
| system=stripped_system, |
| tools=available_tools, |
| messages=stripped_messages, |
| ) |
|
|
| text_parts = [block.text for block in response.content if block.type == "text"] |
| tool_use_blocks = [block for block in response.content if block.type == "tool_use"] |
|
|
| total_input_tokens += getattr(response.usage, "input_tokens", 0) or 0 |
| total_output_tokens += getattr(response.usage, "output_tokens", 0) or 0 |
|
|
| if response.stop_reason == "end_turn" or not tool_use_blocks: |
            # Opus pricing: $15 per 1M input tokens, $75 per 1M output tokens.
| cost_usd = (total_input_tokens / 1_000_000 * 15.0) + (total_output_tokens / 1_000_000 * 75.0) |
| return { |
| "answer": "\n".join(text_parts).strip(), |
| "tool_calls": tool_calls_log, |
| "plotly_json": plotly_jsons[-1] if plotly_jsons else None, |
| "plotly_jsons": plotly_jsons, |
| "route": route.mode, |
| "allowed_tools": _tool_names(available_tools), |
| "usage": { |
| "input_tokens": total_input_tokens, |
| "output_tokens": total_output_tokens, |
| "cost_usd": round(cost_usd, 4), |
| }, |
| } |
|
|
| tool_results = [] |
| for block in tool_use_blocks: |
| fn = TOOL_FN_MAP.get(block.name) |
| if fn is None: |
| result = {"error": f"Unknown tool: {block.name}"} |
| else: |
| try: |
| result = fn(block.input) |
| if isinstance(result, dict) and result.get("plotly_json"): |
| plotly_jsons.append(result["plotly_json"]) |
| except Exception as exc: |
| result = {"error": str(exc)} |
|
|
| safe_result = _safe_str(result) |
| tool_calls_log.append({"tool": block.name, "args": block.input, "result": result}) |
| tool_results.append({ |
| "type": "tool_result", |
| "tool_use_id": block.id, |
| "content": json.dumps(safe_result, default=str, ensure_ascii=True), |
| }) |
|
|
| assistant_content = [_safe_str(block.model_dump()) for block in response.content] |
| messages.append({"role": "assistant", "content": assistant_content}) |
| messages.append({"role": "user", "content": tool_results}) |
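

# Example (sketch): single-question invocation. Requires ANTHROPIC_API_KEY in
# the environment; TOGETHER_API_KEY is only needed for image questions.
def _example_run_agent() -> str:
    out = run_agent("Which subreddit has the most posts?")
    return out["answer"]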
|
|