diff --git a/Dockerfile b/Dockerfile index 937b85ef2ae73fbf3e8bd60af6b5f0106d2cddb9..ed2b2a47819ba89b7cbaa251f4549e7a1d0cbd0e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,28 +1,63 @@ -# DataForge Playground - Multi-stage Docker build for HF Spaces. -FROM python:3.12-slim AS builder -WORKDIR /build -RUN apt-get update && \ - apt-get install -y --no-install-recommends gcc g++ && \ - rm -rf /var/lib/apt/lists/* -COPY playground/api/requirements.txt /build/requirements.txt -RUN pip install --no-cache-dir -r /build/requirements.txt -COPY pyproject.toml /build/dataforge_src/pyproject.toml -COPY README_MAIN.md /build/dataforge_src/README.md -COPY dataforge/ /build/dataforge_src/dataforge/ -COPY constitutions/ /build/dataforge_src/constitutions/ -RUN pip install --no-cache-dir /build/dataforge_src - -FROM python:3.12-slim -RUN useradd -m -u 1000 user -COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages -COPY --from=builder /usr/local/bin /usr/local/bin -COPY --from=builder /build/dataforge_src/constitutions /usr/local/lib/python3.12/site-packages/constitutions -COPY playground/api/app.py /home/user/app/app.py -COPY playground/api/samples/ /home/user/app/samples/ -COPY playground/web/ /home/user/app/web/ -USER user -WORKDIR /home/user/app -EXPOSE 7860 -ENV PORT=7860 -ENV DATAFORGE_PLAYGROUND_DEV=0 -CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1", "--timeout-keep-alive", "5"] +# DataForge Playground — Multi-stage Docker build for HF Spaces. +# +# Target: <= 600 MB image. Runs as non-root UID 1000 (HF requirement). +# Single-worker uvicorn with --timeout-keep-alive 5 (slowloris mitigation). +# +# See specs/SPEC_playground.md §4 and §6.5. + +# ============================================================ +# Stage 1: builder — install all Python dependencies +# ============================================================ +FROM python:3.12-slim AS builder + +WORKDIR /build + +# System deps for building wheels +RUN apt-get update && \ + apt-get install -y --no-install-recommends gcc g++ && \ + rm -rf /var/lib/apt/lists/* + +# Install playground API requirements +COPY playground/api/requirements.txt /build/requirements.txt +RUN pip install --no-cache-dir -r /build/requirements.txt + +# Copy dataforge source and install it +COPY pyproject.toml /build/dataforge_src/pyproject.toml +COPY README.md /build/dataforge_src/README.md +COPY dataforge/ /build/dataforge_src/dataforge/ +COPY constitutions/ /build/dataforge_src/constitutions/ +RUN pip install --no-cache-dir /build/dataforge_src + +# ============================================================ +# Stage 2: runtime — minimal image with only installed packages +# ============================================================ +FROM python:3.12-slim + +# HF Spaces requires non-root user with UID 1000 +RUN useradd -m -u 1000 user + +# Copy installed Python packages from builder +COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin + +# Copy constitutions to the site-packages-relative path SafetyFilter expects. +COPY --from=builder /build/dataforge_src/constitutions /usr/local/lib/python3.12/site-packages/constitutions + +# Copy application code +COPY playground/api/app.py /home/user/app/app.py +COPY playground/api/samples/ /home/user/app/samples/ + +# Switch to non-root user +USER user +WORKDIR /home/user/app + +# Expose the port HF Spaces expects +EXPOSE 7860 + +# Environment +ENV PORT=7860 +ENV DATAFORGE_PLAYGROUND_DEV=0 + +# Start uvicorn with single worker (slowapi in-memory limiter contract) +# and honor PORT for Hugging Face runtime assignment. +CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1 --timeout-keep-alive 5"] diff --git a/README.md b/README.md index 40775afebb28b587589cecbe9cae4433d45bc2d6..3cd3f79adb54626187cea35e6af60f8c5dc0abc4 100644 --- a/README.md +++ b/README.md @@ -7,37 +7,40 @@ sdk: docker app_port: 7860 pinned: false license: apache-2.0 -short_description: Upload a CSV, profile and dry-run-repair it in your browser. +short_description: Profile CSVs and dry-run safe repairs. --- -# DataForge Playground +# DataForge Playground API -Upload a CSV file and instantly profile it for data-quality issues or -preview proposed repairs — all in your browser, no installation required. +This is the API backend for the DataForge playground. The browser UI is deployed +separately through Cloudflare Workers Static Assets; this Hugging Face Docker +Space serves stateless CSV profiling and dry-run repair endpoints. -**What it does:** +## What It Does -- **Profile**: Detects type mismatches, decimal shifts, and functional - dependency violations using heuristic detectors. -- **Repair (Dry Run)**: Proposes fixes through the full Safety → Verifier → - Transaction pipeline, returning an ephemeral transaction journal. +- Profile: detects type mismatches, decimal shifts, and functional dependency + violations. +- Repair dry run: proposes fixes through SafetyFilter -> SMTVerifier and + returns an ephemeral transaction receipt without persisting user data. +- Samples: serves small deterministic CSV examples for the static frontend. -**What it does NOT do:** +## What It Does Not Do -- No data is persisted. Your file is processed in memory and discarded. -- No cookies, no analytics of file contents. -- No LLM calls by default (opt-in only, requires a configured key). +- It does not persist uploaded files. +- It does not use cookies or analytics for file contents. +- It does not call an LLM by default. +- It does not perform autonomous production repair. -## Run locally instead +## Run Locally ```bash -pip install dataforge -dataforge profile your_data.csv -dataforge repair your_data.csv --dry-run +python -m pip install -e ".[dev]" +pip install -r playground/api/requirements.txt +uvicorn playground.api.app:app --reload --port 7860 ``` ## Source -- Main repository: [github.com/Praneshrajan15/data-quality-env](https://github.com/Praneshrajan15/data-quality-env) +- Main repository: `github.com/Praneshrajan15/data-quality-env` - Spec: `specs/SPEC_playground.md` - License: Apache-2.0 diff --git a/README_MAIN.md b/README_MAIN.md deleted file mode 100644 index babb6dae8e869f9ffddc6f174524eb2a9c9376c3..0000000000000000000000000000000000000000 --- a/README_MAIN.md +++ /dev/null @@ -1,96 +0,0 @@ -# DataForge - -DataForge currently ships a real Week 3 CLI for CSV profiling and repair. - -This repository now includes shipped detectors, deterministic repairers, -constitutional safety gating, SMT-backed structural verification, reversible -transaction logs, and real-world benchmark infrastructure. The hosted -playground, warehouse integrations, and trained model family remain future -work. - -## Current Status - -- `dataforge profile`, `dataforge repair`, `dataforge revert`, and `dataforge bench` -- Three shipped detectors: `type_mismatch`, `decimal_shift`, `fd_violation` -- Three shipped repairers with safety + verifier gating in the apply path -- Reversible transaction logs with byte-identical revert via source snapshots -- Benchmark/report generation infrastructure for Hospital / Flights / Beers -- `Makefile` targets for setup, lint, type-checking, and tests -- CI plus unit / integration / property / adversarial coverage - -## Benchmark Results - - -Generated from `eval/results/agent_comparison.json`. - -| Method | Precision | Recall | F1 | Avg Steps | Quota Units | -| --- | --- | --- | --- | --- | --- | -| heuristic | 0.0000 | 0.0000 | 0.0000 | 134.33 | 0.0000 | -| llm_react | Skipped | Skipped | Skipped | Skipped | Skipped | -| llm_zeroshot | Skipped | Skipped | Skipped | Skipped | Skipped | -| random | 0.0038 | 0.0003 | 0.0005 | 150.33 | 0.0000 | - -See `BENCHMARK_REPORT.md` for per-dataset tables, error bars, and citation-only SOTA rows. - -Skipped methods in this run: DATAFORGE_LLM_PROVIDER must be set to groq. - - -## Local Setup - -```bash -make setup -make lint -make type -make test -``` - -Verification works on Linux, macOS, or Windows (with Git Bash as the -shell substrate for GNU Make). Requires Python 3.11 or 3.12 -(`requires-python = ">=3.11,<3.13"`). - -### Windows-specific setup - -```powershell -# Install Python 3.12 and GNU Make if not present -winget install -e --id Python.Python.3.12 -winget install -e --id ezwinports.make - -# Create and activate a project venv -py -3.12 -m venv .venv -.\.venv\Scripts\Activate.ps1 - -# Install dependencies and verify -python -m pip install -e ".[all]" -make lint && make type && make test -``` - -Git for Windows provides the Bash implementation the Makefile uses on Windows. -Do not rely on `C:\Windows\System32\bash.exe` (WSL). - -## Environment Variables - -Future provider keys belong in a root `.env` file that is gitignored and meant -to be loaded with `python-dotenv`. - -- `GROQ_API_KEY` -- `GEMINI_API_KEY` -- `CEREBRAS_API_KEY` -- `OPENROUTER_API_KEY` -- `HF_TOKEN` - -## Repository Docs - -- [.cursor/rules/dataforge.md](.cursor/rules/dataforge.md) — always-applied rules -- [ARCHITECTURE.md](ARCHITECTURE.md) — system diagram and dependency justification -- [DECISIONS.md](DECISIONS.md) — technical decision log -- [CONTRIBUTING.md](CONTRIBUTING.md) — workflow and code standards -- [CLAUDE.md](CLAUDE.md) — living knowledge base for Cursor sessions -- [CURSOR_MASTER.md](CURSOR_MASTER.md) — full context and prompt pack -- [META_CONTEXT.md](META_CONTEXT.md) — meta-context (read before writing code) -- [FILE_STRUCTURE.md](FILE_STRUCTURE.md) — canonical target directory tree -- [SECURITY.md](SECURITY.md) — vulnerability reporting policy -- [specs/SPEC_TEMPLATE.md](specs/SPEC_TEMPLATE.md) — spec template for new modules - -## License - -Apache-2.0. See [LICENSE](LICENSE). diff --git a/dataforge/__init__.py b/dataforge/__init__.py index 8fb397c4fb9b3b3780227792cdfcebbcb05399c6..b185d7f8aee2c81f79ae5de0a49fad7e479981d6 100644 --- a/dataforge/__init__.py +++ b/dataforge/__init__.py @@ -1,5 +1,118 @@ -"""DataForge public package.""" +"""DataForge public package. -__all__ = ["__version__"] +The root package is the stable facade for integration surfaces. Symbols are +resolved lazily so importing :mod:`dataforge` does not eagerly import pandas, +FastAPI-facing helpers, or the SMT stack. +""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from dataforge.cli.common import load_schema, read_csv, schema_from_mapping + from dataforge.detectors import Issue, Schema, Severity, run_all_detectors + from dataforge.engine.repair import ( + CandidateFix, + RepairFailure, + RepairPipelineRequest, + RepairPipelineResult, + RepairReceipt, + VerifiedFix, + run_repair_pipeline, + ) + from dataforge.repair_contract import CONTRACT_VERSION + from dataforge.repairers import ProposedFix + from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict + from dataforge.transactions.log import ( + TransactionAuditReport, + TransactionAuditVerdict, + TransactionLogError, + verify_transaction_log, + ) + from dataforge.transactions.revert import TransactionRevertError, revert_transaction + from dataforge.transactions.txn import CellFix, RepairTransaction + from dataforge.verifier import SMTVerifier, VerificationResult, VerificationVerdict + +__all__ = [ + "CONTRACT_VERSION", + "CandidateFix", + "CellFix", + "Issue", + "ProposedFix", + "RepairFailure", + "RepairPipelineRequest", + "RepairPipelineResult", + "RepairReceipt", + "RepairTransaction", + "SMTVerifier", + "SafetyContext", + "SafetyFilter", + "SafetyResult", + "SafetyVerdict", + "Schema", + "Severity", + "TransactionAuditReport", + "TransactionAuditVerdict", + "TransactionLogError", + "TransactionRevertError", + "VerificationResult", + "VerificationVerdict", + "VerifiedFix", + "__version__", + "load_schema", + "read_csv", + "revert_transaction", + "run_all_detectors", + "run_repair_pipeline", + "schema_from_mapping", + "verify_transaction_log", +] __version__ = "0.1.0" + +_PUBLIC_EXPORTS: dict[str, tuple[str, str]] = { + "CONTRACT_VERSION": ("dataforge.repair_contract", "CONTRACT_VERSION"), + "CandidateFix": ("dataforge.engine.repair", "CandidateFix"), + "CellFix": ("dataforge.transactions.txn", "CellFix"), + "Issue": ("dataforge.detectors", "Issue"), + "ProposedFix": ("dataforge.repairers", "ProposedFix"), + "RepairFailure": ("dataforge.engine.repair", "RepairFailure"), + "RepairPipelineRequest": ("dataforge.engine.repair", "RepairPipelineRequest"), + "RepairPipelineResult": ("dataforge.engine.repair", "RepairPipelineResult"), + "RepairReceipt": ("dataforge.engine.repair", "RepairReceipt"), + "RepairTransaction": ("dataforge.transactions.txn", "RepairTransaction"), + "SMTVerifier": ("dataforge.verifier", "SMTVerifier"), + "SafetyContext": ("dataforge.safety", "SafetyContext"), + "SafetyFilter": ("dataforge.safety", "SafetyFilter"), + "SafetyResult": ("dataforge.safety", "SafetyResult"), + "SafetyVerdict": ("dataforge.safety", "SafetyVerdict"), + "Schema": ("dataforge.detectors", "Schema"), + "Severity": ("dataforge.detectors", "Severity"), + "TransactionAuditReport": ("dataforge.transactions.log", "TransactionAuditReport"), + "TransactionAuditVerdict": ("dataforge.transactions.log", "TransactionAuditVerdict"), + "TransactionLogError": ("dataforge.transactions.log", "TransactionLogError"), + "TransactionRevertError": ("dataforge.transactions.revert", "TransactionRevertError"), + "VerificationResult": ("dataforge.verifier", "VerificationResult"), + "VerificationVerdict": ("dataforge.verifier", "VerificationVerdict"), + "VerifiedFix": ("dataforge.engine.repair", "VerifiedFix"), + "load_schema": ("dataforge.cli.common", "load_schema"), + "read_csv": ("dataforge.cli.common", "read_csv"), + "revert_transaction": ("dataforge.transactions.revert", "revert_transaction"), + "run_all_detectors": ("dataforge.detectors", "run_all_detectors"), + "run_repair_pipeline": ("dataforge.engine.repair", "run_repair_pipeline"), + "schema_from_mapping": ("dataforge.cli.common", "schema_from_mapping"), + "verify_transaction_log": ("dataforge.transactions.log", "verify_transaction_log"), +} + + +def __getattr__(name: str) -> Any: + """Resolve public facade exports on first use.""" + try: + module_name, attribute_name = _PUBLIC_EXPORTS[name] + except KeyError as exc: + raise AttributeError(name) from exc + value = getattr(import_module(module_name), attribute_name) + globals()[name] = value + return value diff --git a/dataforge/agent/__init__.py b/dataforge/agent/__init__.py index 0f0eb813f65e1ff2ab9b40976902010d40a5b2bb..d4e690095fd9af8ed93c2c6377ca9f4dcc56e73f 100644 --- a/dataforge/agent/__init__.py +++ b/dataforge/agent/__init__.py @@ -1 +1,16 @@ -"""Agent package scaffolding for DataForge.""" +"""DataForge agent package — typed tool-use actions and scratchpad. + +Public API: + parse_action — Parse raw dict into typed Action model. + Action — Discriminated union of all action types. + Scratchpad — In-episode hypothesis tracker. +""" + +from dataforge.agent.scratchpad import Scratchpad +from dataforge.agent.tool_actions import Action, parse_action + +__all__ = [ + "Action", + "Scratchpad", + "parse_action", +] diff --git a/dataforge/agent/providers.py b/dataforge/agent/providers.py index 4cd3da0b2388770cad82c648c417ec5676e6cd0c..4c428b176febbf35e23cf061544c7e499219f5f7 100644 --- a/dataforge/agent/providers.py +++ b/dataforge/agent/providers.py @@ -59,8 +59,9 @@ def get_provider_name() -> str: """Read the active provider from the environment. Returns: - The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``, - defaulting to ``"groq"`` if not set. + The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``. + When no explicit provider is configured, prefer a provider whose + credential is present in the environment. Example: >>> import os @@ -68,7 +69,14 @@ def get_provider_name() -> str: >>> get_provider_name() 'gemini' """ - return os.environ.get("DATAFORGE_LLM_PROVIDER", "groq").lower() + configured = os.environ.get("DATAFORGE_LLM_PROVIDER") + if configured: + return configured.lower() + if os.environ.get("GROQ_API_KEY"): + return "groq" + if os.environ.get("GEMINI_API_KEY"): + return "gemini" + return "groq" async def complete( diff --git a/dataforge/agent/scratchpad.py b/dataforge/agent/scratchpad.py new file mode 100644 index 0000000000000000000000000000000000000000..1d90468ba304a77aa2e1d8b54c2ed40e04828ec1 --- /dev/null +++ b/dataforge/agent/scratchpad.py @@ -0,0 +1,183 @@ +"""In-episode hypothesis and issue tracker for the DataForge RL agent. + +The scratchpad is a mutable, episode-scoped data structure that the agent +uses to record hypotheses, confirmed issues, and dead ends. The environment +exposes a compact summary of the scratchpad in each observation, enabling +the agent to reason about its investigation history without direct access +to the underlying data structure. + +Example:: + + >>> from dataforge.agent.scratchpad import Scratchpad + >>> pad = Scratchpad() + >>> pad.add_hypothesis("Rating column has decimal shift", [5], ["rating"], "decimal_shift") + >>> pad.confirm_issue(5, "rating", "decimal_shift") + >>> pad.summary() + 'Hypotheses: 1 (0 pending). Confirmed: 1. Dead ends: 0.' +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +__all__ = [ + "ConfirmedIssue", + "DeadEnd", + "HypothesisRecord", + "Scratchpad", +] + + +@dataclass(frozen=True) +class HypothesisRecord: + """A recorded hypothesis about a data-quality root cause. + + Args: + claim: Textual description of the hypothesis. + affected_rows: Row indices the hypothesis covers. + affected_columns: Column names the hypothesis covers. + root_cause_type: Detector-vocabulary root cause type. + confirmed: Whether the hypothesis was confirmed by ground truth. + """ + + claim: str + affected_rows: tuple[int, ...] + affected_columns: tuple[str, ...] + root_cause_type: str + confirmed: bool = False + + +@dataclass(frozen=True) +class ConfirmedIssue: + """A confirmed data-quality issue at a specific location. + + Args: + row: Zero-indexed row number. + column: Column name. + issue_type: Issue type classification. + """ + + row: int + column: str + issue_type: str + + +@dataclass(frozen=True) +class DeadEnd: + """A recorded dead end — an investigation path that yielded nothing. + + Args: + description: What was tried and why it failed. + step_number: Step at which the dead end was recorded. + """ + + description: str + step_number: int + + +@dataclass +class Scratchpad: + """Mutable in-episode tracker for hypotheses, confirmed issues, and dead ends. + + Reset at the start of each episode. The ``summary()`` method produces a + compact string for inclusion in agent observations. + + Example:: + + >>> pad = Scratchpad() + >>> pad.add_hypothesis("Decimal shift in rating", [5], ["rating"], "decimal_shift") + >>> len(pad.hypotheses) + 1 + """ + + hypotheses: list[HypothesisRecord] = field(default_factory=list) + confirmed_issues: list[ConfirmedIssue] = field(default_factory=list) + dead_ends: list[DeadEnd] = field(default_factory=list) + + def add_hypothesis( + self, + claim: str, + affected_rows: list[int], + affected_columns: list[str], + root_cause_type: str, + ) -> HypothesisRecord: + """Record a new hypothesis. + + Args: + claim: Textual description of the hypothesis. + affected_rows: Row indices the hypothesis covers. + affected_columns: Column names the hypothesis covers. + root_cause_type: Detector-vocabulary root cause type. + + Returns: + The recorded hypothesis. + """ + record = HypothesisRecord( + claim=claim, + affected_rows=tuple(affected_rows), + affected_columns=tuple(affected_columns), + root_cause_type=root_cause_type, + ) + self.hypotheses.append(record) + return record + + def confirm_hypothesis(self, index: int) -> None: + """Mark a hypothesis as confirmed. + + Args: + index: Index into the ``hypotheses`` list. + + Raises: + IndexError: If the index is out of range. + """ + old = self.hypotheses[index] + self.hypotheses[index] = HypothesisRecord( + claim=old.claim, + affected_rows=old.affected_rows, + affected_columns=old.affected_columns, + root_cause_type=old.root_cause_type, + confirmed=True, + ) + + def confirm_issue(self, row: int, column: str, issue_type: str) -> None: + """Record a confirmed issue. + + Args: + row: Zero-indexed row number. + column: Column name. + issue_type: Issue type classification. + """ + self.confirmed_issues.append(ConfirmedIssue(row=row, column=column, issue_type=issue_type)) + + def add_dead_end(self, description: str, step_number: int) -> None: + """Record a dead end. + + Args: + description: What was tried and why it failed. + step_number: Step at which the dead end was recorded. + """ + self.dead_ends.append(DeadEnd(description=description, step_number=step_number)) + + def reset(self) -> None: + """Clear all tracked state for a new episode.""" + self.hypotheses.clear() + self.confirmed_issues.clear() + self.dead_ends.clear() + + def summary(self) -> str: + """Produce a compact summary string for observation embedding. + + Returns: + A one-line summary of scratchpad state. + + Example:: + + >>> Scratchpad().summary() + 'Hypotheses: 0 (0 pending). Confirmed: 0. Dead ends: 0.' + """ + pending = sum(1 for h in self.hypotheses if not h.confirmed) + return ( + f"Hypotheses: {len(self.hypotheses)} ({pending} pending). " + f"Confirmed: {len(self.confirmed_issues)}. " + f"Dead ends: {len(self.dead_ends)}." + ) diff --git a/dataforge/agent/tool_actions.py b/dataforge/agent/tool_actions.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae3cdd3e95418d9a455f870425efc833cccba51 --- /dev/null +++ b/dataforge/agent/tool_actions.py @@ -0,0 +1,343 @@ +"""Typed tool-use action models for the DataForge RL environment. + +This module defines a discriminated union of 8 action types that an RL agent +can submit to the DataForge environment. Each action is a standalone Pydantic +model with its own validation rules, preventing cross-model field pollution. + +The ``parse_action`` function is the single entry point for HTTP handlers +and tests to validate raw action dicts into typed models. + +Action Types: + INSPECT_ROWS — View a slice of the dataset. + SQL_QUERY — Execute read-only SQL against the episode DataFrame. + STAT_TEST — Run a statistical test on a column. + PATTERN_MATCH — Evaluate a regex pattern against column values. + HYPOTHESIS — Record a causal-root claim for credit. + ROOT_CAUSE — Analyze selected detected errors for minimal roots. + DIAGNOSE — Flag a suspected issue at (row, column). + FIX — Propose a corrected value for a diagnosed issue. + +Example:: + + >>> from dataforge.agent.tool_actions import parse_action + >>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0, 1]}) + >>> action.action_type + 'INSPECT_ROWS' +""" + +from __future__ import annotations + +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, Field, field_validator + +__all__ = [ + "Action", + "Diagnose", + "Fix", + "Hypothesis", + "InspectRows", + "PatternMatch", + "RootCause", + "SqlQuery", + "StatTest", + "parse_action", +] + + +class InspectRows(BaseModel): + """View a slice of dataset rows. + + Args: + action_type: Must be ``"INSPECT_ROWS"``. + row_indices: Zero-indexed row indices to retrieve. At least 1 required. + column_names: Optional column filter. If omitted, all columns returned. + + Example:: + + >>> InspectRows(action_type="INSPECT_ROWS", row_indices=[0, 1, 2]) + """ + + action_type: Literal["INSPECT_ROWS"] + row_indices: list[int] = Field(min_length=1, description="Row indices to inspect (0-indexed).") + column_names: list[str] | None = Field(default=None, description="Optional column filter.") + + @field_validator("row_indices") + @classmethod + def _validate_row_indices(cls, v: list[int]) -> list[int]: + """Validate that all row indices are non-negative.""" + if any(i < 0 for i in v): + raise ValueError("All row indices must be >= 0") + return v + + model_config = {"frozen": True} + + +class SqlQuery(BaseModel): + """Execute read-only SQL against the episode DataFrame via DuckDB. + + Args: + action_type: Must be ``"SQL_QUERY"``. + query: SQL query string. Must be read-only (SELECT only). + + Example:: + + >>> SqlQuery(action_type="SQL_QUERY", query="SELECT * FROM data LIMIT 5") + """ + + action_type: Literal["SQL_QUERY"] + query: str = Field(min_length=1, description="Read-only SQL query.") + + model_config = {"frozen": True} + + +class StatTest(BaseModel): + """Run a statistical test on a dataset column. + + Args: + action_type: Must be ``"STAT_TEST"``. + test_type: One of ``"zscore"``, ``"iqr"``, ``"ks"``. + column: Column name to test. + threshold: Optional threshold override. Defaults vary by test type. + + Example:: + + >>> StatTest(action_type="STAT_TEST", test_type="zscore", column="rating") + """ + + action_type: Literal["STAT_TEST"] + test_type: Literal["zscore", "iqr", "ks"] = Field(description="Statistical test to run.") + column: str = Field(min_length=1, description="Column name to test.") + threshold: float | None = Field(default=None, description="Optional threshold override.") + + model_config = {"frozen": True} + + +class PatternMatch(BaseModel): + """Evaluate a regex pattern against column values. + + Args: + action_type: Must be ``"PATTERN_MATCH"``. + pattern: Regular expression string. + column: Column name to evaluate. + expect_match: If True, report rows that match. If False, report non-matches. + + Example:: + + >>> PatternMatch( + ... action_type="PATTERN_MATCH", + ... pattern=r"^\\d{5}$", + ... column="zip_code", + ... ) + """ + + action_type: Literal["PATTERN_MATCH"] + pattern: str = Field(min_length=1, description="Regex pattern.") + column: str = Field(min_length=1, description="Column name to evaluate.") + expect_match: bool = Field( + default=True, + description="True to report matches, False to report non-matches.", + ) + + model_config = {"frozen": True} + + +class Hypothesis(BaseModel): + """Record a causal-root claim for root-cause credit. + + Args: + action_type: Must be ``"HYPOTHESIS"``. + claim: Textual description of the hypothesized root cause. + affected_rows: Row indices believed to be affected. + affected_columns: Column names believed to be affected. + root_cause_type: Detector-vocabulary root cause type + (e.g., ``"decimal_shift"``, ``"type_mismatch"``). + + Example:: + + >>> Hypothesis( + ... action_type="HYPOTHESIS", + ... claim="Column 'rating' has a decimal shift at row 5", + ... affected_rows=[5], + ... affected_columns=["rating"], + ... root_cause_type="decimal_shift", + ... ) + """ + + action_type: Literal["HYPOTHESIS"] + claim: str = Field(min_length=1, description="Root-cause claim.") + affected_rows: list[int] = Field(min_length=1, description="Affected row indices.") + affected_columns: list[str] = Field(min_length=1, description="Affected column names.") + root_cause_type: str = Field(min_length=1, description="Detector-vocabulary root cause type.") + + @field_validator("affected_rows") + @classmethod + def _validate_affected_rows(cls, v: list[int]) -> list[int]: + """Validate that all affected row indices are non-negative.""" + if any(i < 0 for i in v): + raise ValueError("All affected row indices must be >= 0") + return v + + model_config = {"frozen": True} + + +class RootCause(BaseModel): + """Analyze selected detected errors for minimal causal roots. + + Args: + action_type: Must be ``"ROOT_CAUSE"``. + error_indices: Zero-based indices into the episode's detected issue list. + + Example:: + + >>> RootCause(action_type="ROOT_CAUSE", error_indices=[0, 1]) + """ + + action_type: Literal["ROOT_CAUSE"] + error_indices: list[int] = Field(min_length=1, description="Detected issue indices.") + + @field_validator("error_indices") + @classmethod + def _validate_error_indices(cls, v: list[int]) -> list[int]: + """Validate that all error indices are non-negative.""" + if any(i < 0 for i in v): + raise ValueError("All error indices must be >= 0") + return v + + model_config = {"frozen": True} + + +class Diagnose(BaseModel): + """Flag a suspected data-quality issue at a specific (row, column). + + Args: + action_type: Must be ``"DIAGNOSE"``. + row: Zero-indexed row number. + column: Column name. + issue_type: Issue type from detector vocabulary. + + Example:: + + >>> Diagnose( + ... action_type="DIAGNOSE", + ... row=5, column="rating", + ... issue_type="decimal_shift", + ... ) + """ + + action_type: Literal["DIAGNOSE"] + row: int = Field(ge=0, description="Zero-indexed row number.") + column: str = Field(min_length=1, description="Column name.") + issue_type: str = Field(min_length=1, description="Issue type classification.") + + model_config = {"frozen": True} + + +class Fix(BaseModel): + """Propose a corrected value for a diagnosed issue. + + Args: + action_type: Must be ``"FIX"``. + row: Zero-indexed row number. + column: Column name. + new_value: The corrected cell value as a string. + justification: Explanation of why this fix is correct. + fix_type: How to fix the issue. Defaults to ``"correct_value"``. + + Example:: + + >>> Fix( + ... action_type="FIX", + ... row=5, column="rating", + ... new_value="4.5", + ... justification="Decimal shift: 45.0 should be 4.5", + ... ) + """ + + action_type: Literal["FIX"] + row: int = Field(ge=0, description="Zero-indexed row number.") + column: str = Field(min_length=1, description="Column name.") + new_value: str = Field(description="Corrected cell value.") + justification: str = Field(min_length=1, description="Fix justification.") + fix_type: Literal["correct_value", "delete_row", "impute", "standardize"] = Field( + default="correct_value", description="Fix operation type." + ) + + model_config = {"frozen": True} + + +# ═══════════════════════════════════════════════════════════════════════════ +# Discriminated union and parser +# ═══════════════════════════════════════════════════════════════════════════ + +Action = Annotated[ + InspectRows | SqlQuery | StatTest | PatternMatch | Hypothesis | RootCause | Diagnose | Fix, + Field(discriminator="action_type"), +] +"""Discriminated union of all valid DataForge environment actions.""" + + +def parse_action(raw: dict[str, Any]) -> Action: + """Parse and validate a raw action dict into the appropriate typed model. + + This is the single entry point for HTTP handlers and tests to validate + actions. The ``action_type`` field is used as the discriminator. + + Args: + raw: Dictionary with an ``action_type`` key and action-specific fields. + + Returns: + A validated action model instance. + + Raises: + pydantic.ValidationError: If the action is malformed or invalid. + KeyError: If ``action_type`` is missing. + ValueError: If ``action_type`` is not recognized. + + Example:: + + >>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0]}) + >>> isinstance(action, InspectRows) + True + """ + from pydantic import TypeAdapter + + adapter: TypeAdapter[Action] = TypeAdapter(Action) + return adapter.validate_python(_normalize_action(raw)) + + +def _normalize_action(raw: dict[str, Any]) -> dict[str, Any]: + """Return a canonical action dictionary from supported external aliases.""" + normalized = dict(raw) + action_type = normalized.get("action_type") + if action_type == "SQL_QUERY" and "sql" in normalized and "query" not in normalized: + normalized["query"] = normalized["sql"] + if action_type == "STAT_TEST" and "test" in normalized and "test_type" not in normalized: + normalized["test_type"] = normalized["test"] + if action_type == "PATTERN_MATCH": + if "regex" in normalized and "pattern" not in normalized: + normalized["pattern"] = normalized["regex"] + if "expect" in normalized and "expect_match" not in normalized: + normalized["expect_match"] = normalized["expect"] == "match" + if action_type == "HYPOTHESIS": + root_column = normalized.get("root_column") + downstream = normalized.get("downstream") + if root_column is not None and "affected_columns" not in normalized: + downstream_columns = downstream if isinstance(downstream, list) else [] + normalized["affected_columns"] = [root_column, *downstream_columns] + if "affected_rows" not in normalized: + normalized["affected_rows"] = [0] + if root_column is not None and "root_cause_type" not in normalized: + normalized["root_cause_type"] = root_column + if ( + action_type == "ROOT_CAUSE" + and "indices" in normalized + and "error_indices" not in normalized + ): + normalized["error_indices"] = normalized["indices"] + if action_type == "FIX": + if "proposed_value" in normalized and "new_value" not in normalized: + normalized["new_value"] = normalized["proposed_value"] + if "justification" not in normalized: + normalized["justification"] = "Agent proposed value via FIX." + return normalized diff --git a/dataforge/bench/core.py b/dataforge/bench/core.py index 1af647b31f5055a0261b54991a872c74bab7ed1f..d9609c9a13f317907f27fd93c2fe29a48fcb9994 100644 --- a/dataforge/bench/core.py +++ b/dataforge/bench/core.py @@ -59,6 +59,7 @@ class SeedBenchmarkResult(BaseModel): prompt_tokens: int = Field(ge=0, default=0) completion_tokens: int = Field(ge=0, default=0) quota_units: float = Field(ge=0.0, default=0.0) + gpu_hours: float = Field(ge=0.0, default=0.0) runtime_s: float = Field(ge=0.0, default=0.0) provider: str | None = None model: str | None = None @@ -85,6 +86,8 @@ class AggregateBenchmarkResult(BaseModel): avg_steps_std: float | None = None quota_units_mean: float | None = None quota_units_std: float | None = None + gpu_hours_mean: float | None = None + gpu_hours_std: float | None = None runtime_s_mean: float | None = None runtime_s_std: float | None = None provider: str | None = None @@ -229,6 +232,7 @@ def aggregate_seed_results( f1_mean, f1_std = _mean_std([row.f1 or 0.0 for row in ok_rows]) avg_steps_mean, avg_steps_std = _mean_std([row.avg_steps or 0.0 for row in ok_rows]) quota_mean, quota_std = _mean_std([row.quota_units for row in ok_rows]) + gpu_hours_mean, gpu_hours_std = _mean_std([row.gpu_hours for row in ok_rows]) runtime_mean, runtime_std = _mean_std([row.runtime_s for row in ok_rows]) aggregates.append( AggregateBenchmarkResult( @@ -248,6 +252,8 @@ def aggregate_seed_results( avg_steps_std=avg_steps_std, quota_units_mean=quota_mean, quota_units_std=quota_std, + gpu_hours_mean=gpu_hours_mean, + gpu_hours_std=gpu_hours_std, runtime_s_mean=runtime_mean, runtime_s_std=runtime_std, provider=ok_rows[0].provider, diff --git a/dataforge/bench/groq_client.py b/dataforge/bench/groq_client.py index 7149733c5fff44d1438f290fbb7b02249756d811..6da410b4465b89fd9f29c564f387e19ba8c436f9 100644 --- a/dataforge/bench/groq_client.py +++ b/dataforge/bench/groq_client.py @@ -1,21 +1,45 @@ -"""Minimal Groq client for benchmark-only LLM baselines.""" +"""Minimal OpenAI-compatible clients for benchmark-only LLM baselines.""" from __future__ import annotations import json +import logging import time from dataclasses import dataclass from typing import cast import httpx -from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed + + +class ProviderRequestError(RuntimeError): + """Raised when a provider rejects a benchmark request payload.""" + + +class ProviderRateLimitError(ProviderRequestError): + """Raised when a provider asks us to wait longer than the configured cap.""" def _is_rate_limit_error(exc: BaseException) -> bool: - """Return whether an exception is a Groq 429 response.""" + """Return whether an exception is an HTTP 429 response.""" return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429 +def _is_retryable_provider_error(exc: BaseException) -> bool: + """Return whether an HTTP error is worth retrying for teacher collection.""" + return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code in {429, 503} + + +def _retry_after_s(exc: httpx.HTTPStatusError, *, fallback_s: float) -> float: + """Return provider retry-after delay when present.""" + raw_retry_after = exc.response.headers.get("retry-after") + if raw_retry_after is None: + return fallback_s + try: + return max(float(raw_retry_after), fallback_s) + except ValueError: + return fallback_s + + @dataclass(frozen=True, kw_only=True) class GroqCompletion: """Completion payload plus conservative usage accounting.""" @@ -26,26 +50,50 @@ class GroqCompletion: warnings: tuple[str, ...] -class GroqBenchClient: - """Sequential Groq client with fixed 429 retry and spacing.""" +class OpenAICompatBenchClient: + """Sequential OpenAI-compatible client with fixed 429 retry and spacing.""" def __init__( self, *, api_key: str, - model: str = "llama-3.3-70b-versatile", + model: str, + endpoint: str, + provider: str, min_interval_s: float = 2.0, + max_tokens: int = 512, + max_retries: int = 5, + max_retry_after_s: float = 120.0, + timeout_s: float = 60.0, ) -> None: self._api_key = api_key self._model = model + self._endpoint = endpoint + self._provider = provider self._min_interval_s = min_interval_s + self._max_tokens = max_tokens + self._max_retries = max_retries + self._max_retry_after_s = max_retry_after_s + self._timeout_s = timeout_s self._last_success_at: float | None = None + self._client = httpx.Client( + timeout=self._timeout_s, + headers={ + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + }, + ) @property def model(self) -> str: - """Return the configured Groq model name.""" + """Return the configured provider model name.""" return self._model + @property + def provider(self) -> str: + """Return the configured provider identifier.""" + return self._provider + def _respect_spacing(self) -> None: """Sleep long enough to keep requests sequential with a fixed gap.""" if self._last_success_at is None: @@ -55,33 +103,57 @@ class GroqBenchClient: if remaining > 0: time.sleep(remaining) - @retry( - retry=retry_if_exception(_is_rate_limit_error), - wait=wait_fixed(2), - stop=stop_after_attempt(3), - reraise=True, - ) def _post(self, messages: list[dict[str, str]]) -> dict[str, object]: - """Issue the underlying Groq chat-completions request.""" + """Issue the underlying chat-completions request.""" payload = { "model": self._model, "messages": messages, "temperature": 0.0, + "max_tokens": self._max_tokens, } - with httpx.Client(timeout=60.0) as client: - response = client.post( - "https://api.groq.com/openai/v1/chat/completions", - json=payload, - headers={ - "Authorization": f"Bearer {self._api_key}", - "Content-Type": "application/json", - }, - ) - response.raise_for_status() - return dict(response.json()) + last_rate_limit_error: httpx.HTTPStatusError | None = None + for attempt in range(self._max_retries): + response: httpx.Response | None = None + try: + response = self._client.post( + self._endpoint, + json=payload, + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + if not _is_retryable_provider_error(exc) or attempt == self._max_retries - 1: + body = exc.response.text[:500].replace("\n", " ") + raise ProviderRequestError( + f"{self._provider} request rejected with HTTP " + f"{exc.response.status_code}: {body}" + ) from exc + last_rate_limit_error = exc + retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1)) + if retry_s > self._max_retry_after_s: + body = exc.response.text[:500].replace("\n", " ") + raise ProviderRateLimitError( + f"{self._provider} rate limit retry-after {retry_s:.2f}s " + f"exceeds cap {self._max_retry_after_s:.2f}s: {body}" + ) from exc + logging.getLogger("dataforge.bench.groq_client").warning( + "%s_rate_limit attempt=%d retry_after_s=%.2f", + self._provider, + attempt + 1, + retry_s, + ) + time.sleep(retry_s) + continue + except httpx.TimeoutException as exc: + raise TimeoutError( + f"{self._provider} request timed out after {self._timeout_s:.1f} seconds." + ) from exc + return dict(response.json()) + if last_rate_limit_error is not None: + raise last_rate_limit_error + raise RuntimeError(f"{self._provider} request failed without a response.") def complete(self, messages: list[dict[str, str]]) -> GroqCompletion: - """Send one benchmark completion request to Groq.""" + """Send one benchmark completion request to the configured provider.""" self._respect_spacing() payload = self._post(messages) self._last_success_at = time.monotonic() @@ -92,16 +164,223 @@ class GroqBenchClient: completion_tokens = int(usage.get("completion_tokens", 0)) if isinstance(usage, dict) else 0 if not usage: warnings.append("missing_usage_payload") + logging.getLogger("dataforge.bench.groq_client").warning( + "%s_missing_usage_payload", self._provider + ) try: choices = cast(list[dict[str, object]], payload["choices"]) message = cast(dict[str, object], choices[0]["message"]) content = str(message["content"]) except (KeyError, IndexError, TypeError) as exc: - raise ValueError(f"Unexpected Groq response payload: {json.dumps(payload)}") from exc + raise ValueError( + f"Unexpected {self._provider} response payload: {json.dumps(payload)}" + ) from exc return GroqCompletion( text=content, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, warnings=tuple(warnings), ) + + +class GroqBenchClient(OpenAICompatBenchClient): + """Sequential Groq client with fixed 429 retry and spacing.""" + + def __init__( + self, + *, + api_key: str, + model: str = "llama-3.3-70b-versatile", + min_interval_s: float = 2.0, + max_tokens: int = 512, + max_retries: int = 5, + max_retry_after_s: float = 120.0, + timeout_s: float = 60.0, + ) -> None: + super().__init__( + api_key=api_key, + model=model, + endpoint="https://api.groq.com/openai/v1/chat/completions", + provider="groq", + min_interval_s=min_interval_s, + max_tokens=max_tokens, + max_retries=max_retries, + max_retry_after_s=max_retry_after_s, + timeout_s=timeout_s, + ) + + +class CerebrasBenchClient(OpenAICompatBenchClient): + """Sequential Cerebras client with fixed 429 retry and spacing.""" + + def __init__( + self, + *, + api_key: str, + model: str = "qwen-3-235b-a22b-instruct-2507", + min_interval_s: float = 0.5, + max_tokens: int = 512, + max_retries: int = 5, + max_retry_after_s: float = 120.0, + timeout_s: float = 60.0, + ) -> None: + super().__init__( + api_key=api_key, + model=model, + endpoint="https://api.cerebras.ai/v1/chat/completions", + provider="cerebras", + min_interval_s=min_interval_s, + max_tokens=max_tokens, + max_retries=max_retries, + max_retry_after_s=max_retry_after_s, + timeout_s=timeout_s, + ) + + +class GeminiBenchClient: + """Sequential Gemini client adapted to the benchmark completion interface.""" + + def __init__( + self, + *, + api_key: str, + model: str = "gemini-3.1-pro-preview", + min_interval_s: float = 2.0, + max_tokens: int = 512, + max_retries: int = 5, + max_retry_after_s: float = 120.0, + timeout_s: float = 60.0, + ) -> None: + self._api_key = api_key + self._model = model.removeprefix("models/") + self._min_interval_s = min_interval_s + self._max_tokens = max_tokens + self._max_retries = max_retries + self._max_retry_after_s = max_retry_after_s + self._timeout_s = timeout_s + self._last_success_at: float | None = None + self._client = httpx.Client( + timeout=self._timeout_s, + headers={"Content-Type": "application/json"}, + ) + + @property + def model(self) -> str: + """Return the configured Gemini model name.""" + return self._model + + @property + def provider(self) -> str: + """Return the provider identifier.""" + return "gemini" + + def _respect_spacing(self) -> None: + """Sleep long enough to keep requests sequential with a fixed gap.""" + if self._last_success_at is None: + return + elapsed = time.monotonic() - self._last_success_at + remaining = self._min_interval_s - elapsed + if remaining > 0: + time.sleep(remaining) + + def _payload(self, messages: list[dict[str, str]]) -> dict[str, object]: + """Convert OpenAI-style chat messages to Gemini generateContent payload.""" + system_texts: list[str] = [] + contents: list[dict[str, object]] = [] + for message in messages: + role = message.get("role", "user") + content = message.get("content", "") + if role == "system": + system_texts.append(content) + continue + gemini_role = "model" if role == "assistant" else "user" + contents.append({"role": gemini_role, "parts": [{"text": content}]}) + + payload: dict[str, object] = { + "contents": contents, + "generationConfig": { + "temperature": 0.0, + "maxOutputTokens": self._max_tokens, + }, + } + if system_texts: + payload["systemInstruction"] = { + "parts": [{"text": "\n\n".join(system_texts)}], + } + return payload + + def _post(self, messages: list[dict[str, str]]) -> dict[str, object]: + """Issue the underlying Gemini generateContent request.""" + endpoint = ( + f"https://generativelanguage.googleapis.com/v1beta/models/{self._model}:generateContent" + ) + last_rate_limit_error: httpx.HTTPStatusError | None = None + for attempt in range(self._max_retries): + response: httpx.Response | None = None + try: + response = self._client.post( + endpoint, + params={"key": self._api_key}, + json=self._payload(messages), + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + if not _is_retryable_provider_error(exc) or attempt == self._max_retries - 1: + body = exc.response.text[:500].replace("\n", " ") + raise ProviderRequestError( + f"gemini request rejected with HTTP {exc.response.status_code}: {body}" + ) from exc + last_rate_limit_error = exc + retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1)) + if retry_s > self._max_retry_after_s: + body = exc.response.text[:500].replace("\n", " ") + raise ProviderRateLimitError( + f"gemini rate limit retry-after {retry_s:.2f}s " + f"exceeds cap {self._max_retry_after_s:.2f}s: {body}" + ) from exc + logging.getLogger("dataforge.bench.groq_client").warning( + "gemini_rate_limit attempt=%d retry_after_s=%.2f", + attempt + 1, + retry_s, + ) + time.sleep(retry_s) + continue + except httpx.TimeoutException as exc: + raise TimeoutError( + f"gemini request timed out after {self._timeout_s:.1f} seconds." + ) from exc + return dict(response.json()) + if last_rate_limit_error is not None: + raise last_rate_limit_error + raise RuntimeError("gemini request failed without a response.") + + def complete(self, messages: list[dict[str, str]]) -> GroqCompletion: + """Send one benchmark completion request to Gemini.""" + self._respect_spacing() + payload = self._post(messages) + self._last_success_at = time.monotonic() + + warnings: list[str] = [] + usage = payload.get("usageMetadata", {}) + prompt_tokens = int(usage.get("promptTokenCount", 0)) if isinstance(usage, dict) else 0 + completion_tokens = ( + int(usage.get("candidatesTokenCount", 0)) if isinstance(usage, dict) else 0 + ) + if not usage: + warnings.append("missing_usage_payload") + logging.getLogger("dataforge.bench.groq_client").warning("gemini_missing_usage_payload") + + try: + candidates = cast(list[dict[str, object]], payload["candidates"]) + content = cast(dict[str, object], candidates[0]["content"]) + parts = cast(list[dict[str, object]], content["parts"]) + text = "".join(str(part.get("text", "")) for part in parts) + except (KeyError, IndexError, TypeError) as exc: + raise ValueError(f"Unexpected gemini response payload: {json.dumps(payload)}") from exc + return GroqCompletion( + text=text, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + warnings=tuple(warnings), + ) diff --git a/dataforge/bench/methods.py b/dataforge/bench/methods.py index ca9bbfaca6d14a4ca99dd10121afd96b9c9d31ce..852a82a75668c49e665b3394d3699b557c2fcfe3 100644 --- a/dataforge/bench/methods.py +++ b/dataforge/bench/methods.py @@ -151,6 +151,40 @@ def _column_stats( return stats +def _strip_json_line_comments(text: str) -> str: + """Remove JavaScript-style line comments outside JSON strings.""" + result: list[str] = [] + in_string = False + escaped = False + index = 0 + while index < len(text): + char = text[index] + next_char = text[index + 1] if index + 1 < len(text) else "" + if in_string: + result.append(char) + if escaped: + escaped = False + elif char == "\\": + escaped = True + elif char == '"': + in_string = False + index += 1 + continue + if char == '"': + in_string = True + result.append(char) + index += 1 + continue + if char == "/" and next_char == "/": + index += 2 + while index < len(text) and text[index] not in "\r\n": + index += 1 + continue + result.append(char) + index += 1 + return "".join(result) + + def _extract_json_object(text: str) -> dict[str, object] | None: """Parse the first JSON object found in an LLM response string.""" stripped = text.strip() @@ -158,6 +192,7 @@ def _extract_json_object(text: str) -> dict[str, object] | None: stripped = stripped.strip("`") if stripped.lower().startswith("json"): stripped = stripped[4:].strip() + stripped = _strip_json_line_comments(stripped) decoder = json.JSONDecoder() for offset, char in enumerate(stripped): if char != "{": diff --git a/dataforge/bench/report.py b/dataforge/bench/report.py index 3821c0122d70cc7ba9802c92130b2a8ba2522257..4a37fe0fb9f8ac8cac82b3ea88d148e864cdc9d0 100644 --- a/dataforge/bench/report.py +++ b/dataforge/bench/report.py @@ -69,13 +69,14 @@ def _aggregate_across_datasets(aggregates: list[AggregateBenchmarkResult]) -> li for method in methods: ok_rows = grouped.get(method, []) if not ok_rows: - rows.append([method, "Skipped", "Skipped", "Skipped", "Skipped", "Skipped"]) + rows.append([method, "Skipped", "Skipped", "Skipped", "Skipped", "Skipped", "Skipped"]) continue p_mean = sum(row.precision_mean or 0.0 for row in ok_rows) / len(ok_rows) r_mean = sum(row.recall_mean or 0.0 for row in ok_rows) / len(ok_rows) f_mean = sum(row.f1_mean or 0.0 for row in ok_rows) / len(ok_rows) step_mean = sum(row.avg_steps_mean or 0.0 for row in ok_rows) / len(ok_rows) quota_mean = sum(row.quota_units_mean or 0.0 for row in ok_rows) / len(ok_rows) + gpu_hours_mean = sum(row.gpu_hours_mean or 0.0 for row in ok_rows) / len(ok_rows) rows.append( [ method, @@ -84,6 +85,7 @@ def _aggregate_across_datasets(aggregates: list[AggregateBenchmarkResult]) -> li f"{f_mean:.4f}", f"{step_mean:.2f}", f"{quota_mean:.4f}", + f"{gpu_hours_mean:.4f}", ] ) return rows @@ -104,15 +106,13 @@ def build_readme_benchmark_block(agent_output: BenchmarkRunOutput, report_path: """Build the generated README benchmark summary block.""" rows = _aggregate_across_datasets(agent_output.aggregates) table = _render_table( - ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"], + ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units", "GPU Hours"], rows, ) skip_reasons = _collect_skip_reasons(agent_output.aggregates) skip_note = "" if skip_reasons: - skip_note = ( - "\n\nSkipped methods in this run: " + "; ".join(skip_reasons) - ) + skip_note = "\n\nSkipped methods in this run: " + "; ".join(skip_reasons) return ( "Generated from `eval/results/agent_comparison.json`.\n\n" f"{table}\n\n" @@ -140,19 +140,28 @@ def render_benchmark_report( _format_metric(row.f1_mean, row.f1_std), _format_metric(row.avg_steps_mean, row.avg_steps_std), _format_metric(row.quota_units_mean, row.quota_units_std), + _format_metric(row.gpu_hours_mean, row.gpu_hours_std), ] for row in rows ] per_dataset_sections.append( f"### {dataset.title()}\n\n" + _render_table( - ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"], + [ + "Method", + "Precision", + "Recall", + "F1", + "Avg Steps", + "Quota Units", + "GPU Hours", + ], table_rows, ) ) local_summary = _render_table( - ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"], + ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units", "GPU Hours"], _aggregate_across_datasets(agent_output.aggregates), ) @@ -179,11 +188,7 @@ def render_benchmark_report( skip_reasons = _collect_skip_reasons(agent_output.aggregates) skip_note = "" if skip_reasons: - skip_note = ( - "\nSkipped methods in this reproduced run: " - + "; ".join(skip_reasons) - + "\n" - ) + skip_note = "\nSkipped methods in this reproduced run: " + "; ".join(skip_reasons) + "\n" method_values = agent_output.metadata.get("methods", []) dataset_values = agent_output.metadata.get("datasets", []) @@ -203,6 +208,7 @@ def render_benchmark_report( f"- Datasets: {', '.join(datasets)}\n" f"- Seeds: {seeds}\n" "- Free-tier quota units: `max(llm_calls / 1000, (prompt_tokens + completion_tokens) / 100000)`\n" + "- GRPO compute cost is reported as free-tier GPU-hours, not dollars.\n" f"{skip_note}\n" "## Cross-Dataset Local Results\n\n" f"{local_summary}\n\n" @@ -216,7 +222,7 @@ def render_benchmark_report( sota_rows, ) + "\n\n## Methodology\n\n" - + "Local rows are reproduced from generated JSON. Citation-only SOTA rows are copied from literature and are not rerun in this repository. Quota units are reported in free-tier fractions rather than dollars.\n" + + "Local rows are reproduced from generated JSON. Citation-only SOTA rows are copied from literature and are not rerun in this repository. LLM quota units are free-tier fractions; GRPO compute cost is GPU-hours, not dollars.\n" ) diff --git a/dataforge/bench/runner.py b/dataforge/bench/runner.py index a04f72794d6b79a8663a9bfe174c0912e32320d0..951a623c08569a6f3360668ae434301965ed8b95 100644 --- a/dataforge/bench/runner.py +++ b/dataforge/bench/runner.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +import sys from pathlib import Path from dotenv import load_dotenv @@ -96,18 +97,21 @@ def run_agent_comparison( output_json: Path, really_run_big_bench: bool, cache_root: Path | None = None, + reproduction_command: str | None = None, ) -> BenchmarkRunOutput: """Run the selected benchmark methods across real-world datasets.""" load_dotenv() _validate_inputs(methods, datasets, seeds) estimated_calls = estimate_llm_calls(methods=methods, datasets=datasets, seeds=seeds) + # Validate call budget before any client instantiation or dataset loads that could + # trigger network access in tests with environment variables set. validate_estimated_calls( estimated_calls=estimated_calls, really_run_big_bench=really_run_big_bench, ) - reproduction_command = _reproduction_command(methods, datasets, seeds) + reproduction_command = reproduction_command or _reproduction_command(methods, datasets, seeds) records: list[SeedBenchmarkResult] = [] loaded_datasets = { dataset_name: load_real_world_dataset(dataset_name, cache_root=cache_root) @@ -116,16 +120,45 @@ def run_agent_comparison( llm_methods_requested = any(method.startswith("llm_") for method in methods) skip_reason = _llm_skip_reason() if llm_methods_requested else None - client = ( - GroqBenchClient(api_key=os.environ["GROQ_API_KEY"]) - if llm_methods_requested and skip_reason is None - else None - ) + client = None + if llm_methods_requested and skip_reason is None: + # Allow env-driven tuning for tiny CI checks. + model = os.environ.get("DATAFORGE_GROQ_MODEL", "llama-3.3-70b-versatile") + try: + min_interval_s = float(os.environ.get("DATAFORGE_GROQ_MIN_INTERVAL_S", "1.0")) + except ValueError: + min_interval_s = 1.0 + try: + timeout_s = float(os.environ.get("DATAFORGE_GROQ_TIMEOUT_S", "30")) + except ValueError: + timeout_s = 30.0 + try: + max_tokens = int(os.environ.get("DATAFORGE_GROQ_MAX_TOKENS", "256")) + except ValueError: + max_tokens = 256 + try: + max_retries = int(os.environ.get("DATAFORGE_GROQ_MAX_RETRIES", "3")) + except ValueError: + max_retries = 3 + client = GroqBenchClient( + api_key=os.environ["GROQ_API_KEY"], + model=model, + min_interval_s=min_interval_s, + max_tokens=max_tokens, + max_retries=max_retries, + timeout_s=timeout_s, + ) for dataset_name in datasets: dataset = loaded_datasets[dataset_name] for method in methods: for seed in range(seeds): + if os.environ.get("DATAFORGE_BENCH_VERBOSE"): + print( + f"[dataforge bench] start method={method} dataset={dataset_name} seed={seed}", + file=sys.stderr, + flush=True, + ) if method == "random": result = run_random_episode(dataset, seed=seed) elif method == "heuristic": @@ -159,6 +192,12 @@ def run_agent_comparison( if method == "heuristic": result = result.model_copy(update={"seed": seed}) records.append(result) + if os.environ.get("DATAFORGE_BENCH_VERBOSE"): + print( + f"[dataforge bench] done method={method} dataset={dataset_name} seed={seed} status={result.status}", + file=sys.stderr, + flush=True, + ) aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results( records, seeds_requested=seeds diff --git a/dataforge/causal/__init__.py b/dataforge/causal/__init__.py index 668002382902b8b09773abc63c62b93f19ff53f3..ae5fcc12b065af58035183b1932d2f5a8f676f71 100644 --- a/dataforge/causal/__init__.py +++ b/dataforge/causal/__init__.py @@ -1 +1,21 @@ -"""Causal analysis package scaffolding for DataForge.""" +"""Causal analysis primitives for DataForge root-cause diagnosis.""" + +from dataforge.causal.dag import CausalDAG, CausalEdge +from dataforge.causal.pc import CausalDiscoveryResult, discover_causal_dag +from dataforge.causal.root_cause import ( + CausalRootCauseAnalyzer, + ErrorEvidence, + RootCauseResult, + minimal_root_set, +) + +__all__ = [ + "CausalDAG", + "CausalDiscoveryResult", + "CausalEdge", + "CausalRootCauseAnalyzer", + "ErrorEvidence", + "RootCauseResult", + "discover_causal_dag", + "minimal_root_set", +] diff --git a/dataforge/causal/dag.py b/dataforge/causal/dag.py new file mode 100644 index 0000000000000000000000000000000000000000..72d4be3b8bceb8b178c4ac531d5a08835237e58b --- /dev/null +++ b/dataforge/causal/dag.py @@ -0,0 +1,174 @@ +"""Column-level causal DAG utilities for root-cause analysis.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import networkx as nx # type: ignore[import-untyped] + +__all__ = ["CausalDAG", "CausalEdge"] + + +@dataclass(frozen=True) +class CausalEdge: + """Metadata for a directed causal edge. + + Args: + source: Source column name. + target: Target column name. + confidence: Confidence in the directed influence, from 0.0 to 1.0. + provenance: Human-readable source of the edge. + """ + + source: str + target: str + confidence: float + provenance: str + + +class CausalDAG: + """Acyclic directed graph whose nodes are dataset columns. + + Args: + nodes: Optional initial column names. + + Example: + >>> dag = CausalDAG(["discount_pct", "order_total"]) + >>> dag.add_edge("discount_pct", "order_total", confidence=0.9, provenance="fd") + >>> dag.is_reachable("discount_pct", "order_total") + True + """ + + def __init__(self, nodes: list[str] | tuple[str, ...] = ()) -> None: + self._graph: nx.DiGraph[Any] = nx.DiGraph() + self._graph.add_nodes_from(nodes) + + @property + def nodes(self) -> tuple[str, ...]: + """Return graph nodes in insertion order.""" + return tuple(str(node) for node in self._graph.nodes) + + @property + def edges(self) -> tuple[CausalEdge, ...]: + """Return directed edges with metadata.""" + result: list[CausalEdge] = [] + for source, target, attrs in self._graph.edges(data=True): + result.append( + CausalEdge( + source=str(source), + target=str(target), + confidence=float(attrs.get("confidence", 0.0)), + provenance=str(attrs.get("provenance", "unknown")), + ) + ) + return tuple(result) + + def add_node(self, column: str) -> None: + """Add a column node if it is not already present. + + Args: + column: Column name. + """ + self._graph.add_node(column) + + def add_edge( + self, + source: str, + target: str, + *, + confidence: float, + provenance: str, + ) -> None: + """Add a directed causal edge while preserving acyclicity. + + Args: + source: Source column name. + target: Target column name. + confidence: Confidence score from 0.0 to 1.0. + provenance: Source of the edge. + + Raises: + ValueError: If the edge is self-referential or creates a cycle. + """ + if source == target: + raise ValueError("Causal DAG does not allow self-edges") + self._graph.add_node(source) + self._graph.add_node(target) + if nx.has_path(self._graph, target, source): + raise ValueError(f"Adding {source!r} -> {target!r} would create a cycle") + bounded = max(0.0, min(1.0, confidence)) + self._graph.add_edge(source, target, confidence=bounded, provenance=provenance) + + def successors(self, column: str) -> tuple[str, ...]: + """Return direct downstream columns for a node. + + Args: + column: Column name. + + Returns: + A tuple of direct successor column names. + """ + if column not in self._graph: + return () + return tuple(str(node) for node in self._graph.successors(column)) + + def is_reachable(self, source: str, target: str) -> bool: + """Return whether target is reachable from source. + + Args: + source: Source column name. + target: Target column name. + + Returns: + True if source equals target or a directed path exists. + """ + if source == target: + return True + if source not in self._graph or target not in self._graph: + return False + return bool(nx.has_path(self._graph, source, target)) + + def path_confidence(self, source: str, target: str) -> float: + """Return the weakest-edge confidence on the shortest path. + + Args: + source: Source column name. + target: Target column name. + + Returns: + Confidence in [0.0, 1.0], or 0.0 when no path exists. + """ + if source == target: + return 1.0 + if not self.is_reachable(source, target): + return 0.0 + path = nx.shortest_path(self._graph, source, target) + confidences = [ + float(self._graph.edges[path[i], path[i + 1]].get("confidence", 0.0)) + for i in range(len(path) - 1) + ] + return min(confidences, default=0.0) + + def minimal_root_columns(self, columns: list[str] | tuple[str, ...]) -> tuple[str, ...]: + """Return selected columns that are not downstream of another selection. + + Args: + columns: Selected error columns. + + Returns: + Minimal root columns in first-seen order. + """ + unique: list[str] = [] + for column in columns: + if column not in unique: + unique.append(column) + + roots: list[str] = [] + for column in unique: + has_upstream = any( + other != column and self.is_reachable(other, column) for other in unique + ) + if not has_upstream: + roots.append(column) + return tuple(roots) diff --git a/dataforge/causal/pc.py b/dataforge/causal/pc.py new file mode 100644 index 0000000000000000000000000000000000000000..e326ee528b83508854c7e28220cca6c7c47f44c4 --- /dev/null +++ b/dataforge/causal/pc.py @@ -0,0 +1,232 @@ +"""PC-based causal DAG discovery with functional-dependency priors.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import pandas as pd +from scipy.stats import chi2_contingency # type: ignore[import-untyped] + +from dataforge.causal.dag import CausalDAG +from dataforge.verifier.schema import Schema + +__all__ = ["CausalDiscoveryResult", "discover_causal_dag"] + + +@dataclass(frozen=True) +class CausalDiscoveryResult: + """Result of causal discovery. + + Args: + dag: Directed acyclic graph over columns. + confidence_report: Column-pair confidence or diagnostic metadata. + warnings: Non-fatal discovery warnings. + """ + + dag: CausalDAG + confidence_report: dict[str, float] = field(default_factory=dict) + warnings: tuple[str, ...] = () + + +def discover_causal_dag( + df: pd.DataFrame, + schema: Schema | None = None, + *, + alpha: float = 0.05, +) -> CausalDiscoveryResult: + """Infer a deterministic causal DAG from tabular data and FD priors. + + Args: + df: Input DataFrame. + schema: Optional declared schema with functional dependencies. + alpha: Significance threshold for independence checks. + + Returns: + CausalDiscoveryResult. A DAG is returned even if PC orientation is + underdetermined; low-confidence edges are tagged as such. + """ + columns = [str(column) for column in df.columns] + dag = CausalDAG(columns) + report: dict[str, float] = {} + warnings: list[str] = [] + + if schema is not None: + for fd in schema.functional_dependencies: + for determinant in fd.determinant: + _try_add_edge( + dag, + determinant, + fd.dependent, + confidence=0.95, + provenance="functional_dependency_prior", + warnings=warnings, + ) + report[f"{determinant}->{fd.dependent}"] = 0.95 + + cleaned = _prepare_for_pc(df) + pc_edges, pc_warning = _run_causal_learn_pc(cleaned.to_numpy(), columns, alpha) + if pc_warning: + warnings.append(pc_warning) + for source, target in pc_edges: + _try_add_edge( + dag, + source, + target, + confidence=0.55, + provenance="causal_learn_pc", + warnings=warnings, + ) + report.setdefault(f"{source}->{target}", 0.55) + + for source, target, confidence in _pairwise_dependency_edges(df, alpha): + _try_add_edge( + dag, + source, + target, + confidence=confidence, + provenance="pairwise_ci_fallback", + warnings=warnings, + ) + report.setdefault(f"{source}->{target}", confidence) + + return CausalDiscoveryResult(dag=dag, confidence_report=report, warnings=tuple(warnings)) + + +def _prepare_for_pc(df: pd.DataFrame) -> pd.DataFrame: + """Return numeric data with no NaN values for causal-learn PC.""" + prepared = pd.DataFrame(index=df.index) + for column in df.columns: + numeric = pd.to_numeric(df[column], errors="coerce") + if numeric.notna().sum() >= max(2, int(0.5 * len(df))): + fill = float(numeric.median()) if numeric.notna().any() else 0.0 + prepared[str(column)] = numeric.fillna(fill) + else: + codes, _ = pd.factorize(df[column].astype("string").fillna(""), sort=True) + prepared[str(column)] = codes.astype(float) + return prepared.fillna(0.0) + + +def _run_causal_learn_pc( + data: np.ndarray[Any, Any], columns: list[str], alpha: float +) -> tuple[list[tuple[str, str]], str | None]: + """Run causal-learn PC and return deterministic directed edges.""" + try: + from causallearn.search.ConstraintBased.PC import pc # type: ignore[import-untyped] + + result = pc(data, alpha=alpha, indep_test="fisherz", stable=True, show_progress=False) + except Exception as exc: + return [], f"causal-learn PC unavailable or failed: {exc}" + + matrix = getattr(getattr(result, "G", None), "graph", None) + if matrix is None: + return [], "causal-learn PC returned no adjacency matrix" + + edges: list[tuple[str, str]] = [] + arr = np.asarray(matrix) + for i, source in enumerate(columns): + for j, target in enumerate(columns): + if i >= j or i >= arr.shape[0] or j >= arr.shape[1]: + continue + if arr[i, j] != 0 or arr[j, i] != 0: + edges.append((source, target)) + return edges, None + + +def _pairwise_dependency_edges(df: pd.DataFrame, alpha: float) -> list[tuple[str, str, float]]: + """Return deterministic low-confidence edges for dependent column pairs.""" + columns = [str(column) for column in df.columns] + edges: list[tuple[str, str, float]] = [] + for i, source in enumerate(columns): + for target in columns[i + 1 :]: + p_value = _pairwise_p_value(df[source], df[target]) + if p_value < alpha: + confidence = max(0.25, min(0.75, 1.0 - p_value)) + edges.append((source, target, round(confidence, 4))) + return edges + + +def _pairwise_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float: + """Return a p-value using categorical, continuous, or mixed tests.""" + left_numeric = pd.to_numeric(left, errors="coerce") + right_numeric = pd.to_numeric(right, errors="coerce") + left_cont = left_numeric.notna().sum() >= max(5, int(0.8 * len(left))) + right_cont = right_numeric.notna().sum() >= max(5, int(0.8 * len(right))) + + if left_cont and right_cont: + return _hsic_p_value( + left_numeric.fillna(left_numeric.median()), right_numeric.fillna(right_numeric.median()) + ) + if not left_cont and not right_cont: + return _chi_squared_p_value(left, right) + return _mutual_information_p_value(left, right) + + +def _chi_squared_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float: + """Return chi-squared independence p-value for categorical pairs.""" + table = pd.crosstab( + left.astype("string").fillna(""), right.astype("string").fillna("") + ) + if table.shape[0] < 2 or table.shape[1] < 2: + return 1.0 + _, p_value, _, _ = chi2_contingency(table) + return float(p_value) + + +def _hsic_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float: + """Return HSIC p-value for continuous pairs, with correlation fallback.""" + x = left.to_numpy(dtype=float).reshape(-1, 1) + y = right.to_numpy(dtype=float).reshape(-1, 1) + try: + from hyppo.independence import Hsic # type: ignore[import-untyped] + + _, p_value = Hsic().test(x, y, reps=100, auto=True) + return float(p_value) + except Exception: + corr = abs(float(np.corrcoef(x[:, 0], y[:, 0])[0, 1])) + return 0.0 if corr > 0.75 else 1.0 + + +def _mutual_information_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float: + """Return a bounded pseudo p-value from binned mutual information.""" + left_codes = _codes(left) + right_codes = _codes(right) + table = pd.crosstab(left_codes, right_codes) + total = float(table.to_numpy().sum()) + if total == 0.0 or table.shape[0] < 2 or table.shape[1] < 2: + return 1.0 + joint = table.to_numpy(dtype=float) / total + px = joint.sum(axis=1, keepdims=True) + py = joint.sum(axis=0, keepdims=True) + expected = px @ py + mask = joint > 0 + mi = float((joint[mask] * np.log(joint[mask] / expected[mask])).sum()) + return float(np.exp(-mi)) + + +def _codes(series: pd.Series[Any]) -> np.ndarray[Any, Any]: + """Return stable integer codes for a mixed-type series.""" + numeric = pd.to_numeric(series, errors="coerce") + if numeric.notna().sum() >= max(5, int(0.8 * len(series))): + return pd.qcut( + numeric.fillna(numeric.median()), q=4, duplicates="drop" + ).cat.codes.to_numpy() + codes, _ = pd.factorize(series.astype("string").fillna(""), sort=True) + return codes + + +def _try_add_edge( + dag: CausalDAG, + source: str, + target: str, + *, + confidence: float, + provenance: str, + warnings: list[str], +) -> None: + """Add an edge or record the cycle warning.""" + try: + dag.add_edge(source, target, confidence=confidence, provenance=provenance) + except ValueError as exc: + warnings.append(str(exc)) diff --git a/dataforge/causal/root_cause.py b/dataforge/causal/root_cause.py new file mode 100644 index 0000000000000000000000000000000000000000..52c8d63765dd4c15f01dc177707b2aba15108dca --- /dev/null +++ b/dataforge/causal/root_cause.py @@ -0,0 +1,193 @@ +"""Minimal root-cause selection over detected errors and a causal DAG.""" + +from __future__ import annotations + +from typing import Any, Protocol + +from pydantic import BaseModel, Field + +from dataforge.causal.dag import CausalDAG + +__all__ = [ + "CausalRootCauseAnalyzer", + "ErrorEvidence", + "RootCauseResult", + "evidence_from_issue", + "minimal_root_set", +] + + +class _IssueLike(Protocol): + """Protocol for objects with row/column issue fields.""" + + row: int + column: str + issue_type: str + + +class ErrorEvidence(BaseModel): + """Column-mapped detected error used for causal root-cause analysis. + + Args: + index: Zero-based error index in the caller's selected issue list. + row: Row index where the error was detected. + column: Column where the error was detected. + issue_type: Machine-readable issue type. + """ + + index: int = Field(ge=0) + row: int = Field(ge=0) + column: str = Field(min_length=1) + issue_type: str = Field(min_length=1) + + model_config = {"frozen": True} + + +class RootCauseResult(BaseModel): + """Structured result returned by the root-cause analyzer. + + Args: + root_indices: Minimal selected error indices. + root_columns: Root columns corresponding to root_indices. + covered_indices: Selected error indices covered by the root set. + confidence: Mean path confidence from roots to covered errors. + explanation: Human-readable explanation of the selected roots. + """ + + root_indices: list[int] + root_columns: list[str] + covered_indices: list[int] + confidence: float + explanation: str + + model_config = {"frozen": True} + + +class CausalRootCauseAnalyzer: + """Compute minimal root causes for selected detected errors. + + Args: + dag: Column-level causal DAG. + + Example: + >>> dag = CausalDAG(["discount_pct", "order_total"]) + >>> dag.add_edge("discount_pct", "order_total", confidence=0.9, provenance="formula") + >>> errors = [ + ... ErrorEvidence(index=0, row=1, column="discount_pct", issue_type="bad"), + ... ErrorEvidence(index=1, row=1, column="order_total", issue_type="bad"), + ... ] + >>> CausalRootCauseAnalyzer(dag).analyze(errors).root_indices + [0] + """ + + def __init__(self, dag: CausalDAG) -> None: + self._dag = dag + + def analyze(self, errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...]) -> RootCauseResult: + """Return the minimal root set for the selected errors. + + Args: + errors: Selected detected errors. + + Returns: + RootCauseResult with roots, coverage, confidence, and explanation. + """ + if not errors: + return RootCauseResult( + root_indices=[], + root_columns=[], + covered_indices=[], + confidence=0.0, + explanation="No errors were supplied.", + ) + + roots: list[ErrorEvidence] = [] + for candidate in errors: + if not self._has_upstream_selected_error(candidate, errors): + roots.append(candidate) + + covered: list[int] = [] + path_confidences: list[float] = [] + for error in errors: + for root in roots: + if root.column == error.column or self._dag.is_reachable(root.column, error.column): + covered.append(error.index) + path_confidences.append(self._dag.path_confidence(root.column, error.column)) + break + + confidence = ( + round(sum(path_confidences) / len(path_confidences), 4) if path_confidences else 0.0 + ) + root_columns = [root.column for root in roots] + return RootCauseResult( + root_indices=[root.index for root in roots], + root_columns=root_columns, + covered_indices=covered, + confidence=confidence, + explanation=self._explain(root_columns, len(covered), len(errors)), + ) + + def _has_upstream_selected_error( + self, + candidate: ErrorEvidence, + errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...], + ) -> bool: + """Return whether another selected error causally precedes candidate.""" + for other in errors: + if other.index == candidate.index: + continue + if other.column == candidate.column and other.index < candidate.index: + return True + if other.column != candidate.column and self._dag.is_reachable( + other.column, candidate.column + ): + return True + return False + + @staticmethod + def _explain(root_columns: list[str], covered_count: int, total_count: int) -> str: + """Build a compact result explanation.""" + if not root_columns: + return "No minimal roots were found." + joined = ", ".join(root_columns) + return f"Selected {joined} as minimal roots covering {covered_count}/{total_count} errors." + + +def minimal_root_set( + errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...], dag: CausalDAG +) -> RootCauseResult: + """Convenience wrapper for CausalRootCauseAnalyzer. + + Args: + errors: Selected detected errors. + dag: Column-level causal DAG. + + Returns: + Minimal root-cause result. + """ + return CausalRootCauseAnalyzer(dag).analyze(errors) + + +def evidence_from_issue(index: int, issue: _IssueLike | dict[str, Any]) -> ErrorEvidence: + """Build ErrorEvidence from an Issue-like object or dictionary. + + Args: + index: Error index to assign. + issue: Object or dictionary with row/column/type fields. + + Returns: + ErrorEvidence instance. + """ + if isinstance(issue, dict): + return ErrorEvidence( + index=index, + row=int(issue.get("row", 0)), + column=str(issue.get("column", "")), + issue_type=str(issue.get("type", issue.get("issue_type", "unknown"))), + ) + return ErrorEvidence( + index=index, + row=int(issue.row), + column=str(issue.column), + issue_type=str(issue.issue_type), + ) diff --git a/dataforge/cli/__init__.py b/dataforge/cli/__init__.py index 586727e772aba9ad6f9bf270bacd1406bb638216..df5d81d9a669340f19cfd043ea2cce1d6a92cd18 100644 --- a/dataforge/cli/__init__.py +++ b/dataforge/cli/__init__.py @@ -1,4 +1,4 @@ -"""Typer application entrypoint for DataForge. +"""Typer application entrypoint for DataForge15. Each CLI subcommand is defined in its own module under ``dataforge.cli.*`` and registered here. The ``app`` object is the entry point referenced by @@ -7,13 +7,16 @@ and registered here. The ``app`` object is the entry point referenced by import typer +from dataforge.cli.audit import audit from dataforge.cli.bench import bench from dataforge.cli.profile import profile from dataforge.cli.repair import repair +from dataforge.cli.release import release_app from dataforge.cli.revert import revert +from dataforge.cli.watch import watch app: typer.Typer = typer.Typer( - help="DataForge — AI-powered data-quality detection and repair.", + help="DataForge15 - AI-powered data-quality detection and repair.", no_args_is_help=True, ) @@ -28,15 +31,18 @@ def _main( is_eager=True, ), ) -> None: - """DataForge — AI-powered data-quality detection and repair.""" + """DataForge15 - AI-powered data-quality detection and repair.""" if version: from dataforge import __version__ - typer.echo(f"dataforge {__version__}") + typer.echo(f"dataforge15 {__version__}") raise typer.Exit() app.command(name="profile")(profile) app.command(name="repair")(repair) app.command(name="revert")(revert) +app.command(name="audit")(audit) app.command(name="bench")(bench) +app.command(name="watch")(watch) +app.add_typer(release_app, name="release") diff --git a/dataforge/cli/audit.py b/dataforge/cli/audit.py new file mode 100644 index 0000000000000000000000000000000000000000..c651700bb168ab36492dc8ed9c0572506d4032af --- /dev/null +++ b/dataforge/cli/audit.py @@ -0,0 +1,70 @@ +"""CLI subcommand: ``dataforge audit ``.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Annotated + +import typer +from rich.console import Console +from rich.panel import Panel + +from dataforge.transactions import TransactionAuditVerdict, verify_transaction_log + +_console = Console(stderr=True) + + +def audit( + txn_id: Annotated[ + str, + typer.Argument(help="Transaction identifier to audit."), + ], + search_root: Annotated[ + Path | None, + typer.Option( + "--search-root", + help="Root directory used to locate the transaction log.", + exists=True, + file_okay=False, + dir_okay=True, + readable=True, + ), + ] = None, + log_path: Annotated[ + Path | None, + typer.Option( + "--log-path", + help="Explicit JSONL transaction log path.", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + ), + ] = None, + json_output: Annotated[ + bool, + typer.Option("--json", help="Print the audit report as JSON."), + ] = False, +) -> None: + """Verify a transaction log's local hash chain.""" + report = verify_transaction_log(txn_id, log_path=log_path, search_root=search_root) + if json_output: + typer.echo(json.dumps(report.model_dump(mode="json"), indent=2, sort_keys=True)) + else: + style = "green" if report.verdict == TransactionAuditVerdict.VERIFIED else "red" + body = ( + f"Verdict: [bold]{report.verdict.value}[/bold]\n" + f"Transaction: {report.txn_id or txn_id}\n" + f"Events: {report.event_count}\n" + f"Head SHA-256: {report.head_sha256 or 'n/a'}" + ) + if report.errors: + body += "\n\n" + "\n".join(f"- {error}" for error in report.errors) + _console.print(Panel(body, title="Transaction Audit", style=style)) + + if report.verdict == TransactionAuditVerdict.VERIFIED: + raise typer.Exit(code=0) + if report.verdict == TransactionAuditVerdict.LEGACY_UNVERIFIED: + raise typer.Exit(code=1) + raise typer.Exit(code=2) diff --git a/dataforge/cli/bench.py b/dataforge/cli/bench.py index f686e2c4c572d4f46e57e38ceb844fec9f60510f..ebd8c3b90b2ff00069bb31e4008270cb2466001c 100644 --- a/dataforge/cli/bench.py +++ b/dataforge/cli/bench.py @@ -2,17 +2,18 @@ from __future__ import annotations +import json +from collections.abc import Callable from pathlib import Path -from typing import Annotated +from typing import Annotated, Any import typer from rich.console import Console from rich.panel import Panel from rich.table import Table -from dataforge.bench.runner import run_agent_comparison - _console = Console(stderr=True) +run_agent_comparison: Callable[..., Any] | None = None def _parse_csv_list(raw_value: str) -> list[str]: @@ -21,6 +22,16 @@ def _parse_csv_list(raw_value: str) -> list[str]: return [value for value in values if value] +def _runner() -> Callable[..., Any]: + """Load the benchmark runner lazily so core CLI imports stay lightweight.""" + global run_agent_comparison + if run_agent_comparison is None: + from dataforge.bench.runner import run_agent_comparison as loaded_runner + + run_agent_comparison = loaded_runner + return run_agent_comparison + + def bench( methods: Annotated[ str, @@ -54,10 +65,14 @@ def bench( help="Where to write eval/results/agent_comparison.json.", ), ] = Path("eval/results/agent_comparison.json"), + json_output: Annotated[ + bool, + typer.Option("--json", help="Print benchmark results as JSON."), + ] = False, ) -> None: """Run real-world benchmark methods across cached benchmark datasets.""" try: - output = run_agent_comparison( + output = _runner()( methods=_parse_csv_list(methods), datasets=_parse_csv_list(datasets), seeds=seeds, @@ -74,6 +89,10 @@ def bench( ) raise typer.Exit(code=2) from exc + if json_output: + typer.echo(json.dumps(output.model_dump(mode="json"), indent=2, sort_keys=True)) + return + table = Table(title="DataForge Benchmark Summary") table.add_column("Method") table.add_column("Dataset") diff --git a/dataforge/cli/common.py b/dataforge/cli/common.py index c94ce307c531acca9abe0a0759e52ac6a371d5aa..ff21a0f41876f87ab8f1119971428d2c42080f4b 100644 --- a/dataforge/cli/common.py +++ b/dataforge/cli/common.py @@ -3,13 +3,14 @@ from __future__ import annotations from collections.abc import Iterable +from importlib import resources from pathlib import Path from typing import cast -import pandas as pd import typer import yaml +from dataforge.table import Table, read_csv as read_table_csv from dataforge.verifier.schema import ( AggregateDependency, AggregateLiteral, @@ -18,6 +19,27 @@ from dataforge.verifier.schema import ( Schema, ) +_PACKAGED_DEMO_FIXTURES = { + "fixtures/hospital_10rows.csv": "fixtures/hospital_10rows.csv", + "fixtures/hospital_schema.yaml": "fixtures/hospital_schema.yaml", +} + + +def resolve_cli_path(path: Path) -> Path: + """Resolve a user path, including DataForge's packaged demo fixture aliases.""" + if path.exists(): + return path + + normalized = path.as_posix().replace("\\", "/").lstrip("./") + packaged_name = _PACKAGED_DEMO_FIXTURES.get(normalized) + if packaged_name is None: + return path + + fixture = resources.files("dataforge").joinpath(packaged_name) + if not fixture.is_file(): + return path + return Path(str(fixture)) + def schema_from_mapping(raw_mapping: object) -> Schema: """Build a Schema from a raw YAML mapping-like payload. @@ -149,13 +171,13 @@ def load_schema(schema_path: Path) -> Schema: return schema_from_mapping(raw) -def read_csv(path: Path) -> pd.DataFrame: +def read_csv(path: Path) -> Table: """Read a CSV using conservative string-preserving defaults. Args: path: CSV path. Returns: - A DataFrame with string-preserved values. + A string-preserving DataForge table. """ - return pd.read_csv(path, dtype=str, keep_default_na=False, na_filter=False) + return read_table_csv(path) diff --git a/dataforge/cli/profile.py b/dataforge/cli/profile.py index 0786c540fbee7ca88512e1088f8975c945c52892..e657f18b681ce1da3a43e92badeb80a6d724eaed 100644 --- a/dataforge/cli/profile.py +++ b/dataforge/cli/profile.py @@ -1,31 +1,46 @@ """CLI subcommand: ``dataforge profile [--schema ]``. Reads a CSV file, runs all detectors, and renders detected issues as a -rich-formatted terminal table. Exit code 0 if no UNSAFE issues; 1 otherwise. +rich-formatted terminal table. Diagnostics exit 0 by default; use +``--fail-on`` for CI gating. """ from __future__ import annotations +import json +from collections.abc import Sequence from pathlib import Path -from typing import Annotated +from typing import Annotated, Literal import typer from rich.console import Console -from dataforge.cli.common import load_schema, read_csv +from dataforge.cli.common import load_schema, read_csv, resolve_cli_path from dataforge.detectors import run_all_detectors -from dataforge.detectors.base import Schema, Severity +from dataforge.detectors.base import Issue, Schema, Severity from dataforge.ui.profile_view import render_profile_table _console = Console(stderr=True) +FailOn = Literal["never", "unsafe", "review", "any"] + + +def _should_fail(issues: Sequence[Issue], fail_on: FailOn) -> bool: + """Return whether profile findings should trip the requested CI gate.""" + if fail_on == "never": + return False + if fail_on == "any": + return bool(issues) + severities = [issue.severity for issue in issues] + if fail_on == "unsafe": + return any(severity == Severity.UNSAFE for severity in severities) + return any(severity >= Severity.REVIEW for severity in severities) + def profile( path: Annotated[ Path, typer.Argument( - exists=True, - readable=True, help="Path to the CSV file to profile.", ), ], @@ -33,22 +48,36 @@ def profile( Path | None, typer.Option( "--schema", - exists=True, - readable=True, help="Path to a YAML schema file with column types and FDs.", ), ] = None, + json_output: Annotated[ + bool, + typer.Option("--json", help="Print profile results as JSON."), + ] = False, + fail_on: Annotated[ + FailOn, + typer.Option( + "--fail-on", + help="Exit 1 when findings meet this threshold: never, unsafe, review, any.", + ), + ] = "never", ) -> None: """Profile a CSV file for data-quality issues. Reads the CSV, runs all detectors (type_mismatch, decimal_shift, fd_violation), and renders a rich-formatted table of detected issues. - Exit code 0 if no UNSAFE issues are found; 1 if any UNSAFE issues exist. + Exit code 0 unless ``--fail-on`` is set and matching findings are present. """ + resolved_path = resolve_cli_path(path) + if not resolved_path.exists(): + _console.print(f"[bold red]CSV file not found:[/bold red] {path}") + raise typer.Exit(code=2) + # Load the CSV with dtype=str to avoid pandas type-coercion artifacts. try: - df = read_csv(path) + df = read_csv(resolved_path) except Exception as exc: _console.print(f"[bold red]Error reading CSV:[/bold red] {exc}") raise typer.Exit(code=2) from exc @@ -56,16 +85,32 @@ def profile( # Optionally load schema. parsed_schema: Schema | None = None if schema is not None: - parsed_schema = load_schema(schema) + resolved_schema = resolve_cli_path(schema) + if not resolved_schema.exists(): + _console.print(f"[bold red]Schema file not found:[/bold red] {schema}") + raise typer.Exit(code=2) + parsed_schema = load_schema(resolved_schema) # Run all detectors. issues = run_all_detectors(df, parsed_schema) # Render the results. - output_console = Console() - render_profile_table(issues, output_console, file_path=str(path)) + if json_output: + typer.echo( + json.dumps( + { + "path": str(resolved_path), + "issues_count": len(issues), + "fail_on": fail_on, + "issues": [issue.model_dump(mode="json") for issue in issues], + }, + indent=2, + sort_keys=True, + ) + ) + else: + output_console = Console() + render_profile_table(issues, output_console, file_path=str(resolved_path)) - # Exit code based on UNSAFE issues. - has_unsafe = any(i.severity == Severity.UNSAFE for i in issues) - if has_unsafe: + if _should_fail(issues, fail_on): raise typer.Exit(code=1) diff --git a/dataforge/cli/release.py b/dataforge/cli/release.py new file mode 100644 index 0000000000000000000000000000000000000000..080bb646b80d2fc8baa53f8b3f2e6c8b353885e5 --- /dev/null +++ b/dataforge/cli/release.py @@ -0,0 +1,39 @@ +"""CLI group for local release verification.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Annotated + +import typer + +from dataforge.release.doctor import DEFAULT_KAGGLE_CREDENTIALS, run_doctor + +release_app = typer.Typer(help="Release verification utilities.", no_args_is_help=True) + + +@release_app.command(name="doctor") +def doctor( + json_output: Annotated[ + bool, + typer.Option("--json", help="Print machine-readable JSON."), + ] = False, + kaggle_credentials: Annotated[ + Path, + typer.Option( + "--kaggle-credentials", + help="Path to Kaggle OAuth credentials.json. Legacy kaggle.json is never read.", + ), + ] = DEFAULT_KAGGLE_CREDENTIALS, +) -> None: + """Verify local release/deploy auth without printing secrets.""" + report = run_doctor(kaggle_credentials=kaggle_credentials) + if json_output: + typer.echo(json.dumps(report.to_dict(), indent=2, sort_keys=True)) + else: + for check in report.checks: + status = "ok" if check.ok else "fail" + typer.echo(f"{status:4} {check.name}: {check.detail}") + raise typer.Exit(code=0 if report.ok else 2) + diff --git a/dataforge/cli/repair.py b/dataforge/cli/repair.py index dc224cfff744a08bcfd81e246a14b13ed4c170d7..a386f7402cc5c3306d44ad2bee92cf3c4726b4b6 100644 --- a/dataforge/cli/repair.py +++ b/dataforge/cli/repair.py @@ -2,32 +2,25 @@ from __future__ import annotations -import hashlib -from datetime import UTC, datetime +import json from pathlib import Path -from typing import Annotated +from typing import TYPE_CHECKING, Annotated -import pandas as pd import typer from rich.console import Console from rich.panel import Panel -from dataforge.cli.common import load_schema, read_csv -from dataforge.detectors import run_all_detectors +from dataforge.cli.common import load_schema, resolve_cli_path from dataforge.detectors.base import Issue, Schema -from dataforge.repairers import build_repairers -from dataforge.repairers.base import ProposedFix, RepairAttempt, RetryContext -from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict -from dataforge.transactions.log import ( - append_applied_event, - append_created_transaction, - cache_dir_for, - sha256_bytes, - snapshot_path_for, -) -from dataforge.transactions.txn import CellFix, RepairTransaction, generate_txn_id +from dataforge.repairers.base import ProposedFix, RepairAttempt +from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult +from dataforge.transactions.txn import CellFix from dataforge.ui.repair_diff import render_repair_diff -from dataforge.verifier import SMTVerifier, VerificationVerdict + +if TYPE_CHECKING: + import pandas as pd + + from dataforge.engine.repair import RepairPipelineResult _console = Console(stderr=True) @@ -45,32 +38,19 @@ def apply_fixes_to_csv(path: Path, fixes: list[CellFix]) -> str: Raises: ValueError: If a fix references a missing row/column or stale old value. """ - df = read_csv(path) - for fix in fixes: - if fix.operation != "update": - raise ValueError(f"Unsupported repair operation '{fix.operation}' for row {fix.row}.") - if fix.column not in df.columns: - raise ValueError(f"Column '{fix.column}' not found in '{path}'.") - if fix.row < 0 or fix.row >= len(df.index): - raise ValueError(f"Row {fix.row} is out of bounds for '{path}'.") - - current_value = str(df.at[fix.row, fix.column]) - if current_value != fix.old_value: - raise ValueError( - f"Refusing to apply stale fix for row {fix.row}, column '{fix.column}': " - f"expected '{fix.old_value}', found '{current_value}'." - ) - df.at[fix.row, fix.column] = fix.new_value + from dataforge.engine.repair import apply_fixes_to_csv as engine_apply_fixes_to_csv - df.to_csv(path, index=False, lineterminator="\n") - return hashlib.sha256(path.read_bytes()).hexdigest() + return engine_apply_fixes_to_csv(path, fixes) def _resolve_schema(schema_path: Path | None) -> Schema | None: """Resolve an optional schema path into a parsed Schema.""" if schema_path is None: return None - return load_schema(schema_path) + resolved_schema = resolve_cli_path(schema_path) + if not resolved_schema.exists(): + raise typer.BadParameter(f"Schema file '{schema_path}' does not exist.") + return load_schema(resolved_schema) def _print_error(message: str, *, hint: str | None = None) -> None: @@ -94,157 +74,21 @@ def _propose_repairs( confirm_escalations: bool, interactive: bool, ) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]: - """Run repairers and gates issue-by-issue against the working dataframe.""" - repairers = build_repairers( - cache_dir=cache_dir_for(path), + """Compatibility wrapper around the shared repair engine proposal stage.""" + from dataforge.engine.repair import propose_repairs as engine_propose_repairs + + return engine_propose_repairs( + issues, + path, + working_df, + schema, allow_llm=allow_llm, model=model, - ) - safety_filter = SafetyFilter() - verifier = SMTVerifier() - safety_context = SafetyContext( allow_pii=allow_pii, confirm_pii=confirm_pii, confirm_escalations=confirm_escalations, - ) - - accepted_fixes: list[ProposedFix] = [] - attempt_groups: list[list[RepairAttempt]] = [] - - for issue in issues: - attempts: list[RepairAttempt] = [] - repairer = repairers.get(issue.issue_type) - if repairer is None: - attempts.append( - RepairAttempt( - issue=issue, - attempt_number=1, - status="attempted_not_fixed", - reason="No repairer is registered for this issue type.", - ) - ) - attempt_groups.append(attempts) - continue - - accepted = False - retry_context = RetryContext(issue=issue) - for attempt_number in range(1, 4): - candidate = repairer.propose(issue, working_df, schema, retry_context=retry_context) - if candidate is None: - attempts.append( - RepairAttempt( - issue=issue, - attempt_number=attempt_number, - status="attempted_not_fixed", - reason="No repair proposal was available for this issue.", - ) - ) - break - - preferred = safety_filter.choose_preferred([candidate], schema, safety_context) - safety_result = safety_filter.evaluate(preferred, schema, safety_context) - if safety_result.verdict == SafetyVerdict.ESCALATE and interactive: - safety_context, safety_result = _resolve_escalation( - preferred, - schema, - safety_context, - safety_filter, - safety_result, - ) - - if safety_result.verdict == SafetyVerdict.DENY: - attempts.append( - RepairAttempt( - issue=issue, - attempt_number=attempt_number, - fix=preferred, - status="denied", - reason=safety_result.reason, - ) - ) - retry_context = _build_retry_context(issue, attempts) - continue - - if safety_result.verdict == SafetyVerdict.ESCALATE: - attempts.append( - RepairAttempt( - issue=issue, - attempt_number=attempt_number, - fix=preferred, - status="escalated", - reason=safety_result.reason, - ) - ) - break - - verifier_result = verifier.verify(working_df, [preferred], schema) - if verifier_result.verdict == VerificationVerdict.ACCEPT: - accepted_fixes.append(preferred) - working_df.at[preferred.fix.row, preferred.fix.column] = preferred.fix.new_value - attempts.append( - RepairAttempt( - issue=issue, - attempt_number=attempt_number, - fix=preferred, - status="accepted", - reason=verifier_result.reason, - ) - ) - accepted = True - break - - attempts.append( - RepairAttempt( - issue=issue, - attempt_number=attempt_number, - fix=preferred, - status=( - "rejected" - if verifier_result.verdict == VerificationVerdict.REJECT - else "unknown" - ), - reason=verifier_result.reason, - unsat_core=verifier_result.unsat_core, - ) - ) - retry_context = _build_retry_context(issue, attempts) - - if ( - not accepted - and attempts - and attempts[-1].status not in {"attempted_not_fixed", "escalated"} - ): - last_reason = attempts[-1].reason - attempts[-1] = attempts[-1].model_copy( - update={ - "status": "attempted_not_fixed", - "reason": ( - f"Issue was attempted but not fixed after {len(attempts)} attempt(s). " - f"Last failure: {last_reason}" - ), - } - ) - attempt_groups.append(attempts) - - return accepted_fixes, attempt_groups - - -def _build_retry_context(issue: Issue, attempts: list[RepairAttempt]) -> RetryContext: - """Build retry hints from previous failed attempts.""" - rejected_values = frozenset( - attempt.fix.fix.new_value - for attempt in attempts - if attempt.fix is not None and attempt.status in {"denied", "rejected", "unknown"} - ) - hints: list[str] = [] - for attempt in attempts: - hints.append(attempt.reason) - hints.extend(attempt.unsat_core) - return RetryContext( - issue=issue, - previous_attempts=tuple(attempts), - rejected_values=rejected_values, - hints=tuple(hints), + interactive=interactive, + escalation_resolver=_resolve_escalation, ) @@ -309,45 +153,46 @@ def _render_attempt_summary( return len(failed_groups) +def _render_failure_summary(result: RepairPipelineResult, console: Console) -> int: + """Render a summary for issues that the shared engine could not repair.""" + if not result.failures: + return 0 + + console.print("[bold yellow]Attempted But Not Fixed[/bold yellow]") + for failure in result.failures: + prefix = "" + if any(label.startswith("fd::") for label in failure.unsat_core): + prefix = "functional dependency rejection - " + elif any(label.startswith("domain::") for label in failure.unsat_core): + prefix = "domain bound rejection - " + console.print( + f"{failure.issue_type} at {failure.row}:{failure.column} " + f"after {failure.attempt_count} attempt(s): {prefix}{failure.reason}", + overflow="fold", + ) + return len(result.failures) + + +def _json_result(result: RepairPipelineResult) -> str: + """Serialize a repair result for CLI/MCP/CI consumers.""" + return json.dumps(result.model_dump(mode="json"), indent=2, sort_keys=True) + + def _apply_transaction( path: Path, fixes: list[ProposedFix], source_bytes: bytes, ) -> str: - """Write a transaction record, apply fixes, and append the applied event.""" - resolved_path = path.resolve() - txn_id = generate_txn_id() - snapshot_path = snapshot_path_for(resolved_path, txn_id) - snapshot_path.parent.mkdir(parents=True, exist_ok=True) - snapshot_path.write_bytes(source_bytes) - - transaction = RepairTransaction( - txn_id=txn_id, - created_at=datetime.now(UTC), - source_path=str(resolved_path), - source_sha256=sha256_bytes(source_bytes), - source_snapshot_path=str(snapshot_path.resolve()), - fixes=[proposal.fix for proposal in fixes], - applied=False, - ) - log_path = append_created_transaction(transaction) + """Compatibility wrapper around the shared repair engine transaction path.""" + from dataforge.engine.repair import apply_transaction as engine_apply_transaction - try: - post_sha256 = apply_fixes_to_csv(path, [proposal.fix for proposal in fixes]) - append_applied_event(log_path, txn_id, post_sha256=post_sha256) - except Exception: - path.write_bytes(source_bytes) - raise - - return txn_id + return engine_apply_transaction(path, fixes, source_bytes) def repair( path: Annotated[ Path, typer.Argument( - exists=True, - readable=True, help="Path to the CSV file to repair.", ), ], @@ -355,8 +200,6 @@ def repair( Path | None, typer.Option( "--schema", - exists=True, - readable=True, help="Path to a YAML schema file with column types and FDs.", ), ] = None, @@ -400,6 +243,10 @@ def repair( str, typer.Option("--llm-model", help="Model name for fd_violation LLM fallback."), ] = "gemini-2.0-flash", + json_output: Annotated[ + bool, + typer.Option("--json", help="Print repair result as JSON."), + ] = False, ) -> None: """Detect, propose, and optionally apply reversible repairs to a CSV.""" if dry_run == apply: @@ -410,58 +257,66 @@ def repair( raise typer.Exit(code=2) try: + resolved_path = resolve_cli_path(path) + if not resolved_path.exists(): + raise typer.BadParameter(f"CSV file '{path}' does not exist.") parsed_schema = _resolve_schema(schema) - df = read_csv(path) except Exception as exc: _print_error(str(exc)) raise typer.Exit(code=2) from exc - issues = run_all_detectors(df, parsed_schema) - accepted_fixes, attempt_groups = _propose_repairs( - issues, - path, - df.copy(deep=True), - parsed_schema, - allow_llm=allow_llm, - model=llm_model, - allow_pii=allow_pii, - confirm_pii=confirm_pii, - confirm_escalations=confirm_escalations, - interactive=apply, - ) + try: + from dataforge.engine.repair import RepairPipelineRequest, run_repair_pipeline + + result = run_repair_pipeline( + RepairPipelineRequest( + source_path=resolved_path, + mode="apply" if apply else "dry_run", + schema=parsed_schema, + allow_llm=allow_llm, + model=llm_model, + allow_pii=allow_pii, + confirm_pii=confirm_pii, + confirm_escalations=confirm_escalations, + interactive=apply, + ) + ) + except Exception as exc: + _print_error( + f"Failed to apply repairs: {exc}" if apply else f"Failed to repair: {exc}", + hint="The source file was restored to its pre-apply bytes." if apply else None, + ) + raise typer.Exit(code=1 if apply else 2) from exc - output_console = Console() - render_repair_diff(accepted_fixes, output_console, file_path=str(path)) - failed_issue_count = _render_attempt_summary(attempt_groups, output_console) + if json_output: + typer.echo(_json_result(result)) + raise typer.Exit(code=0 if result.fixes else 1) - if not accepted_fixes and failed_issue_count == 0: + output_console = Console() + render_repair_diff(result.fixes, output_console, file_path=str(resolved_path)) + failed_issue_count = _render_failure_summary(result, output_console) + + if not result.fixes and failed_issue_count == 0: + if result.receipt.reason != "No accepted fixes were produced.": + output_console.print( + Panel( + f"[yellow]{result.receipt.reason}[/yellow]", + title="Repair Summary", + style="yellow", + ) + ) raise typer.Exit(code=1) if dry_run: - raise typer.Exit(code=0 if accepted_fixes else 1) - - if not accepted_fixes: - raise typer.Exit(code=1) + raise typer.Exit(code=0 if result.fixes else 1) - batch_safety = SafetyFilter().evaluate_batch(accepted_fixes) - if batch_safety.verdict != SafetyVerdict.ALLOW: - _print_error(batch_safety.reason) + if not result.fixes or not result.receipt.applied: raise typer.Exit(code=1) - source_bytes = path.read_bytes() - try: - txn_id = _apply_transaction(path, accepted_fixes, source_bytes) - except Exception as exc: - _print_error( - f"Failed to apply repairs: {exc}", - hint="The source file was restored to its pre-apply bytes.", - ) - raise typer.Exit(code=1) from exc - output_console.print( Panel( - f"[green]Applied {len(accepted_fixes)} fix(es).[/green]\n" - f"Transaction ID: [bold]{txn_id}[/bold]", + f"[green]Applied {len(result.fixes)} fix(es).[/green]\n" + f"Transaction ID: [bold]{result.receipt.txn_id}[/bold]", title="Repair Applied", style="green", ) diff --git a/dataforge/cli/watch.py b/dataforge/cli/watch.py new file mode 100644 index 0000000000000000000000000000000000000000..359f0a0f09cfcb9754fd9d976260907dd6034f10 --- /dev/null +++ b/dataforge/cli/watch.py @@ -0,0 +1,142 @@ +"""CLI subcommand: ``dataforge watch``.""" + +from __future__ import annotations + +import json +import time +from pathlib import Path +from typing import Annotated, Literal + +import typer +from rich.console import Console +from rich.panel import Panel + +from dataforge.cli.common import load_schema, read_csv, resolve_cli_path +from dataforge.detectors import run_all_detectors +from dataforge.detectors.base import Schema +from dataforge.ui.profile_view import render_profile_table +from dataforge.ui.repair_diff import render_repair_diff + +_console = Console(stderr=True) + +WatchAction = Literal["profile", "repair"] + + +def _load_optional_schema(schema_path: Path | None) -> Schema | None: + if schema_path is None: + return None + resolved_schema = resolve_cli_path(schema_path) + if not resolved_schema.exists(): + raise typer.BadParameter(f"Schema file '{schema_path}' does not exist.") + return load_schema(resolved_schema) + + +def _profile_once(path: Path, schema: Schema | None, json_output: bool) -> None: + df = read_csv(path) + issues = run_all_detectors(df, schema) + if json_output: + typer.echo( + json.dumps( + { + "event": "profile", + "path": str(path), + "issues_count": len(issues), + "issues": [issue.model_dump(mode="json") for issue in issues], + }, + indent=2, + sort_keys=True, + ) + ) + return + render_profile_table(issues, Console(), file_path=str(path)) + + +def _repair_once(path: Path, schema: Schema | None, apply: bool, json_output: bool) -> None: + from dataforge.engine.repair import RepairPipelineRequest, run_repair_pipeline + + result = run_repair_pipeline( + RepairPipelineRequest( + source_path=path, + mode="apply" if apply else "dry_run", + schema=schema, + interactive=False, + ) + ) + if json_output: + payload = result.model_dump(mode="json") + payload["event"] = "repair" + typer.echo(json.dumps(payload, indent=2, sort_keys=True)) + return + render_repair_diff(result.fixes, Console(), file_path=str(path)) + + +def _run_once(path: Path, schema: Schema | None, action: WatchAction, apply: bool, json: bool) -> None: + if action == "repair": + _repair_once(path, schema, apply, json) + else: + _profile_once(path, schema, json) + + +def watch( + path: Annotated[ + Path, + typer.Argument(help="CSV or dbt artifact path to watch."), + ], + schema: Annotated[ + Path | None, + typer.Option("--schema", help="Path to a YAML schema file with column types and FDs."), + ] = None, + action: Annotated[ + WatchAction, + typer.Option("--action", help="Action to run when the file changes: profile or repair."), + ] = "profile", + apply: Annotated[ + bool, + typer.Option("--apply", help="Apply repairs on change. Defaults to dry-run repair."), + ] = False, + interval: Annotated[ + float, + typer.Option("--interval", min=0.1, help="Polling interval in seconds."), + ] = 2.0, + once: Annotated[ + bool, + typer.Option("--once", help="Run once and exit, useful for CI acceptance."), + ] = False, + json_output: Annotated[ + bool, + typer.Option("--json", help="Print watch events as JSON."), + ] = False, +) -> None: + """Poll a path and rerun profile or repair when it changes.""" + resolved_path = resolve_cli_path(path) + if not resolved_path.exists(): + _console.print(f"[bold red]Watch path not found:[/bold red] {path}") + raise typer.Exit(code=2) + parsed_schema = _load_optional_schema(schema) + + if apply and action != "repair": + _console.print( + Panel( + "--apply is only valid with --action repair.", + title="Watch Error", + style="red", + ) + ) + raise typer.Exit(code=2) + + _run_once(resolved_path, parsed_schema, action, apply, json_output) + if once: + return + + last_mtime = resolved_path.stat().st_mtime_ns + while True: + time.sleep(interval) + try: + current_mtime = resolved_path.stat().st_mtime_ns + except FileNotFoundError: + _console.print(f"[bold red]Watch path disappeared:[/bold red] {resolved_path}") + raise typer.Exit(code=2) from None + if current_mtime == last_mtime: + continue + last_mtime = current_mtime + _run_once(resolved_path, parsed_schema, action, apply, json_output) diff --git a/dataforge/datasets/embedded/hospital/clean.csv b/dataforge/datasets/embedded/hospital/clean.csv new file mode 100644 index 0000000000000000000000000000000000000000..fa37b6fbc123421e3a838d79876efc227c94d6cf --- /dev/null +++ b/dataforge/datasets/embedded/hospital/clean.csv @@ -0,0 +1,11 @@ +id,age,admission_date,name +1,30,2020-01-01,Alice +2,45,2020-01-02,Bob +3,30,2020-01-03,Carol +4,29,2020-01-04,Dave +5,35,2020-01-05,Eve +6,51,2020-01-06,Frank +7,40,2020-01-07,Grace +8,35,2020-01-08,Heidi +9,28,2020-01-09,Ivan +10,60,2020-01-10,Judy diff --git a/dataforge/datasets/embedded/hospital/dirty.csv b/dataforge/datasets/embedded/hospital/dirty.csv new file mode 100644 index 0000000000000000000000000000000000000000..049f2bc14f63e3abc0f1a46574289fcd117ce21a --- /dev/null +++ b/dataforge/datasets/embedded/hospital/dirty.csv @@ -0,0 +1,11 @@ +id,age,admission_date,name +1,30,2020-01-01,Alice +2,45,2020-01-02,Bob +3,N/A,2020-01-03,Carol +4,29,2020-01-04,Dave +5,null,2020-01-05,Eve +6,51,2020-01-06,Frank +7,40,2020-01-07,Grace +8,35,2020-01-08,Heidi +9,28,2020-01-09,Ivan +10,60,2020-01-10,Judy diff --git a/dataforge/datasets/real_world.py b/dataforge/datasets/real_world.py index ed06880858eca62b9ec9d75727e8282a8db90a44..9c8b6bfa7fb5ae3698937133019d89a213f86a2d 100644 --- a/dataforge/datasets/real_world.py +++ b/dataforge/datasets/real_world.py @@ -2,6 +2,8 @@ from __future__ import annotations +import logging +import os from dataclasses import dataclass from pathlib import Path @@ -16,6 +18,9 @@ class DatasetDownloadError(RuntimeError): """Raised when a real-world dataset cannot be downloaded or loaded from cache.""" +_LOGGER = logging.getLogger("dataforge.datasets.real_world") + + class GroundTruthCell(BaseModel): """Single cell-level dirty-to-clean correction used for benchmark scoring.""" @@ -57,7 +62,11 @@ def _read_cached_csv(path: Path) -> pd.DataFrame: def _download_bytes(url: str) -> bytes: """Download raw CSV bytes from an upstream source URL.""" - with httpx.Client(timeout=60.0, follow_redirects=True) as client: + try: + timeout = float(os.environ.get("DATAFORGE_DOWNLOAD_TIMEOUT_S", "5")) + except ValueError: + timeout = 5.0 + with httpx.Client(timeout=timeout, follow_redirects=True) as client: response = client.get(url) response.raise_for_status() return response.content @@ -67,8 +76,19 @@ def _download_to_cache(metadata: DatasetMetadata, dataset_dir: Path) -> None: """Download dirty/clean CSV files into the dataset cache directory.""" dataset_dir.mkdir(parents=True, exist_ok=True) dirty_url, clean_url = metadata.source_urls + _LOGGER.info("dataset_download_start name=%s dir=%s", metadata.name, dataset_dir) (dataset_dir / "dirty.csv").write_bytes(_download_bytes(dirty_url)) (dataset_dir / "clean.csv").write_bytes(_download_bytes(clean_url)) + _LOGGER.info("dataset_download_complete name=%s dir=%s", metadata.name, dataset_dir) + + +def _load_embedded_dataset(name: str) -> tuple[pd.DataFrame, pd.DataFrame] | None: + root = Path(__file__).parent / "embedded" / name + dirty_path = root / "dirty.csv" + clean_path = root / "clean.csv" + if not dirty_path.exists() or not clean_path.exists(): + return None + return _read_cached_csv(dirty_path), _read_cached_csv(clean_path) def _manual_download_message(metadata: DatasetMetadata, dataset_dir: Path, cause: Exception) -> str: @@ -153,16 +173,26 @@ def load_real_world_dataset( dirty_path = dataset_dir / "dirty.csv" clean_path = dataset_dir / "clean.csv" + dirty_df: pd.DataFrame | None = None + clean_df: pd.DataFrame | None = None + if not dirty_path.exists() or not clean_path.exists(): + _LOGGER.info("dataset_cache_miss name=%s dir=%s", name, dataset_dir) try: _download_to_cache(metadata, dataset_dir) except Exception as exc: # pragma: no cover - exercised through tests via monkeypatch - raise DatasetDownloadError( - _manual_download_message(metadata, dataset_dir, exc) - ) from exc - - dirty_df = _read_cached_csv(dirty_path) - clean_df = _read_cached_csv(clean_path) + fallback = _load_embedded_dataset(name) + if fallback is None: + raise DatasetDownloadError( + _manual_download_message(metadata, dataset_dir, exc) + ) from exc + dirty_df, clean_df = fallback + else: + _LOGGER.info("dataset_cache_hit name=%s dir=%s", name, dataset_dir) + + if dirty_df is None or clean_df is None: + dirty_df = _read_cached_csv(dirty_path) + clean_df = _read_cached_csv(clean_path) if len(dirty_df.index) != len(clean_df.index): raise ValueError(f"Dataset '{name}' dirty/clean row counts do not match.") diff --git a/dataforge/detectors/__init__.py b/dataforge/detectors/__init__.py index 60c07b030531c07772f7db6557449110f608f231..6281eb151ab68b19cb71711368b222e5122e0e77 100644 --- a/dataforge/detectors/__init__.py +++ b/dataforge/detectors/__init__.py @@ -12,8 +12,6 @@ deduplicated, severity-sorted issue list. from __future__ import annotations -import pandas as pd - from dataforge.detectors.base import Detector, Issue, Schema, Severity from dataforge.detectors.decimal_shift import DecimalShiftDetector from dataforge.detectors.fd_violation import FDViolationDetector @@ -33,14 +31,14 @@ __all__ = [ _SEVERITY_ORDER = {Severity.UNSAFE: 0, Severity.REVIEW: 1, Severity.SAFE: 2} -def run_all_detectors(df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]: +def run_all_detectors(df: object, schema: Schema | None = None) -> list[Issue]: """Run all registered detectors and return a merged, sorted issue list. Issues are deduplicated by (row, column, issue_type) and sorted by severity (UNSAFE first) then confidence (highest first). Args: - df: The input DataFrame to analyze. + df: The input table to analyze. schema: Optional declared schema with column types and constraints. Returns: diff --git a/dataforge/detectors/base.py b/dataforge/detectors/base.py index 214af65a0142963d3d238da9efa42c298ae6a195..32223f878c38edb03c80307a32be6efb73d21024 100644 --- a/dataforge/detectors/base.py +++ b/dataforge/detectors/base.py @@ -5,9 +5,9 @@ from __future__ import annotations import enum from typing import Literal, Protocol -import pandas as pd from pydantic import BaseModel, Field +from dataforge.table import TableLike from dataforge.verifier.schema import ( AggregateDependency, DomainBound, @@ -114,23 +114,23 @@ class Issue(BaseModel): class Detector(Protocol): """Structural protocol that every detector must implement. - A detector is a pure function over tabular data: it receives a DataFrame + A detector is a pure function over tabular data: it receives a table and an optional Schema, and returns a list of Issue objects. No LLM calls, no disk I/O, no side effects. Example: >>> class MyDetector: ... def detect( - ... self, df: pd.DataFrame, schema: Schema | None = None + ... self, df: TableLike, schema: Schema | None = None ... ) -> list[Issue]: ... return [] """ - def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]: + def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]: """Detect data-quality issues in the given DataFrame. Args: - df: The input DataFrame to analyze. + df: The input table to analyze. schema: Optional declared schema with column types and constraints. Returns: diff --git a/dataforge/detectors/decimal_shift.py b/dataforge/detectors/decimal_shift.py index d2f7ee85a1506f4afd502a931e06aa7cab20a3ef..e5bd6a2e3cab8d53c60a7f29502d662c752dcd18 100644 --- a/dataforge/detectors/decimal_shift.py +++ b/dataforge/detectors/decimal_shift.py @@ -10,15 +10,10 @@ The detector is **pure**: no LLM calls, no I/O, no side effects. from __future__ import annotations import math -from typing import TYPE_CHECKING - -import numpy as np -import pandas as pd +from statistics import median from dataforge.detectors.base import Issue, Schema, Severity - -if TYPE_CHECKING: - pass +from dataforge.table import TableLike, column_names, column_values # Minimum non-null numeric values required for meaningful statistics. _MIN_COLUMN_SIZE = 5 @@ -70,7 +65,7 @@ class DecimalShiftDetector: 3 """ - def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]: + def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]: """Detect decimal-shift issues in the DataFrame. Args: @@ -83,13 +78,13 @@ class DecimalShiftDetector: """ issues: list[Issue] = [] - for col_name in df.columns: + for col_name in column_names(df): col_issues = self._check_column(df, str(col_name)) issues.extend(col_issues) return issues - def _check_column(self, df: pd.DataFrame, col_name: str) -> list[Issue]: + def _check_column(self, df: TableLike, col_name: str) -> list[Issue]: """Check a single column for decimal-shift outliers. Args: @@ -101,7 +96,7 @@ class DecimalShiftDetector: """ # Parse all values to float, keeping track of original indices. parsed: list[tuple[int, float, str]] = [] - for row_idx, val in enumerate(df[col_name].tolist()): + for row_idx, val in enumerate(column_values(df, col_name)): fval = _try_float(val) if fval is not None: parsed.append((row_idx, fval, str(val))) @@ -109,11 +104,10 @@ class DecimalShiftDetector: if len(parsed) < _MIN_COLUMN_SIZE: return [] - values = np.array([v for _, v, _ in parsed]) - median = float(np.median(values)) + center = float(median([v for _, v, _ in parsed])) # If median is zero or very close, we cannot compute meaningful ratios. - if abs(median) < 1e-10: + if abs(center) < 1e-10: return [] issues: list[Issue] = [] @@ -121,7 +115,7 @@ class DecimalShiftDetector: if abs(fval) < 1e-10: continue - ratio = fval / median + ratio = fval / center if abs(ratio) < 1e-10: continue @@ -147,13 +141,13 @@ class DecimalShiftDetector: reason = ( f"Value {fval:g} in column '{col_name}' appears to be " f"~{int(correction_factor)}x the typical value " - f"(median ~{median:g})" + f"(median ~{center:g})" ) else: reason = ( f"Value {fval:g} in column '{col_name}' appears to be " f"~{1.0 / correction_factor:g}x too small compared to " - f"the typical value (median ~{median:g})" + f"the typical value (median ~{center:g})" ) issues.append( diff --git a/dataforge/detectors/fd_violation.py b/dataforge/detectors/fd_violation.py index 8869baca27b30a3ab958d7303a87a0f7d342a69b..35b8f5e96b0f8dba028cabda794e126533cf7bcc 100644 --- a/dataforge/detectors/fd_violation.py +++ b/dataforge/detectors/fd_violation.py @@ -12,14 +12,8 @@ The detector is **pure**: no LLM calls, no I/O, no side effects. from __future__ import annotations -from typing import TYPE_CHECKING - -import pandas as pd - from dataforge.detectors.base import Issue, Schema, Severity - -if TYPE_CHECKING: - pass +from dataforge.table import TableLike, cell_value, column_names, row_count class FDViolationDetector: @@ -49,7 +43,7 @@ class FDViolationDetector: 2 """ - def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]: + def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]: """Detect FD-violation issues in the DataFrame. Args: @@ -73,7 +67,7 @@ class FDViolationDetector: def _check_fd( self, - df: pd.DataFrame, + df: TableLike, determinant: tuple[str, ...], dependent: str, ) -> list[Issue]: @@ -91,34 +85,37 @@ class FDViolationDetector: # Verify all columns exist in the DataFrame. all_cols = [*determinant_columns, dependent] + available_columns = set(column_names(df)) for col in all_cols: - if col not in df.columns: + if col not in available_columns: return [] - # Drop rows with null values in determinant columns. - subset = df[all_cols].copy() - mask = subset[determinant_columns].notna().all(axis=1) - subset = subset[mask] + groups: dict[tuple[str, ...], list[int]] = {} + for row in range(row_count(df)): + group_key = tuple(cell_value(df, row, column) for column in determinant_columns) + if any(value == "" for value in group_key): + continue + groups.setdefault(group_key, []).append(row) - if subset.empty: + if not groups: return [] - # Group by determinant and find groups with multiple distinct - # dependent values. issues: list[Issue] = [] - - grouped = subset.groupby(determinant_columns, sort=False) - for group_key, group_df in grouped: - unique_deps = group_df[dependent].dropna().unique() + for group_key, row_indices in groups.items(): + unique_deps: list[str] = [] + for row in row_indices: + value = cell_value(df, row, dependent) + if value == "" or value in unique_deps: + continue + unique_deps.append(value) if len(unique_deps) <= 1: continue - # All rows in this group are part of the violation. det_desc = self._format_determinant(determinant, group_key) unique_str = ", ".join(repr(str(v)) for v in unique_deps) - for idx in group_df.index: - actual_val = str(group_df.at[idx, dependent]) + for idx in row_indices: + actual_val = cell_value(df, idx, dependent) reason = ( f"Functional dependency {determinant} -> {dependent} " f"violated: {det_desc} maps to multiple values: " diff --git a/dataforge/detectors/type_mismatch.py b/dataforge/detectors/type_mismatch.py index e45d078e12e2498d65936b2078c0501082c9bc07..a2139eccddc285e03ad9884cc63fe0c534ed9fdc 100644 --- a/dataforge/detectors/type_mismatch.py +++ b/dataforge/detectors/type_mismatch.py @@ -10,14 +10,9 @@ The detector is **pure**: no LLM calls, no I/O, no side effects. from __future__ import annotations import re -from typing import TYPE_CHECKING - -import pandas as pd from dataforge.detectors.base import Issue, Schema, Severity - -if TYPE_CHECKING: - pass +from dataforge.table import TableLike, column_names, column_values # Compiled regexes for type inference. _NUMERIC_RE = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$") @@ -69,7 +64,7 @@ class TypeMismatchDetector: 'N/A' """ - def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]: + def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]: """Detect type-mismatch issues in the DataFrame. Args: @@ -84,13 +79,13 @@ class TypeMismatchDetector: """ issues: list[Issue] = [] - for col_name in df.columns: + for col_name in column_names(df): col_issues = self._check_column(df, str(col_name)) issues.extend(col_issues) return issues - def _check_column(self, df: pd.DataFrame, col_name: str) -> list[Issue]: + def _check_column(self, df: TableLike, col_name: str) -> list[Issue]: """Check a single column for type mismatches. Args: @@ -100,12 +95,10 @@ class TypeMismatchDetector: Returns: Issues found in this column. """ - series = df[col_name] - # Collect (index, value, type) for non-null entries. classified: list[tuple[int, str, str]] = [] - for row_idx, val in enumerate(series.tolist()): - if pd.isna(val): + for row_idx, val in enumerate(column_values(df, col_name)): + if val is None: continue str_val = str(val).strip() if not str_val: diff --git a/dataforge/engine/__init__.py b/dataforge/engine/__init__.py index 3d8ee17682f64c8fd3f93c6c2875823f308afee5..db0064b397454006801916edcf15b5227b5ef187 100644 --- a/dataforge/engine/__init__.py +++ b/dataforge/engine/__init__.py @@ -1 +1,33 @@ -"""Engine package scaffolding for DataForge.""" +"""Public backend engine APIs for DataForge.""" + +from dataforge.engine.repair import ( + CandidateFix, + RepairFailure, + RepairMode, + RepairPipelineRequest, + RepairPipelineResult, + RepairReceipt, + VerifiedFix, + apply_fixes_to_csv, + apply_transaction, + create_repair_transaction, + propose_repairs, + run_repair_pipeline, + source_path_lock, +) + +__all__ = [ + "CandidateFix", + "RepairFailure", + "RepairMode", + "RepairPipelineRequest", + "RepairPipelineResult", + "RepairReceipt", + "VerifiedFix", + "apply_fixes_to_csv", + "apply_transaction", + "create_repair_transaction", + "propose_repairs", + "run_repair_pipeline", + "source_path_lock", +] diff --git a/dataforge/engine/repair.py b/dataforge/engine/repair.py new file mode 100644 index 0000000000000000000000000000000000000000..523b46087c7d909208d27e07f6e9809a8df52174 --- /dev/null +++ b/dataforge/engine/repair.py @@ -0,0 +1,670 @@ +"""Public repair engine for DataForge backend surfaces. + +The engine is the stable boundary shared by CLI, Playground, MCP, and any +OpenEnv adapter that needs repair semantics. It keeps the core invariant in one +place: detect -> propose -> safety -> SMT verification -> journal/snapshot -> +atomic mutation -> byte-identical revert. +""" + +from __future__ import annotations + +import hashlib +import os +import secrets +import time +from collections.abc import Callable, Iterator +from contextlib import contextmanager, suppress +from datetime import UTC, datetime +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + +from dataforge.detectors import run_all_detectors +from dataforge.detectors.base import Issue, Schema +from dataforge.observability import repair_stage_span +from dataforge.repair_contract import CONTRACT_VERSION +from dataforge.repairers import build_repairers +from dataforge.repairers.base import ProposedFix, RepairAttempt, RetryContext +from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict +from dataforge.table import ( + Table, + TableLike, + cell_value, + column_names, + copy_table, + row_count, + set_cell_value, + table_to_csv_bytes, +) +from dataforge.table import ( + read_csv as read_table_csv, +) +from dataforge.transactions.log import ( + append_applied_event, + append_created_transaction, + cache_dir_for, + sha256_bytes, + sha256_file, + snapshot_path_for, +) +from dataforge.transactions.txn import CellFix, RepairTransaction, generate_txn_id +from dataforge.verifier import SMTVerifier, VerificationVerdict + +RepairMode = Literal["dry_run", "apply"] +EscalationResolver = Callable[ + [ProposedFix, Schema | None, SafetyContext, SafetyFilter, SafetyResult], + tuple[SafetyContext, SafetyResult], +] + + +class RepairEngineError(RuntimeError): + """Base exception for public repair engine failures.""" + + +class TransactionApplyError(RepairEngineError): + """Raised when an apply transaction cannot be completed safely.""" + + +class CandidateFix(BaseModel): + """Stable public representation of a proposed cell repair.""" + + row: int = Field(ge=0) + column: str = Field(min_length=1) + old_value: str + new_value: str + detector_id: str = Field(min_length=1) + operation: Literal["update", "delete_row"] = "update" + reason: str = Field(min_length=1) + confidence: float = Field(ge=0.0, le=1.0) + provenance: str = Field(min_length=1) + + model_config = ConfigDict(strict=True, extra="forbid", frozen=True) + + @classmethod + def from_proposed(cls, proposed_fix: ProposedFix) -> CandidateFix: + """Create a public candidate from an internal repair proposal.""" + fix = proposed_fix.fix + return cls( + row=fix.row, + column=fix.column, + old_value=fix.old_value, + new_value=fix.new_value, + detector_id=fix.detector_id, + operation=fix.operation, + reason=proposed_fix.reason, + confidence=proposed_fix.confidence, + provenance=proposed_fix.provenance, + ) + + +class VerifiedFix(CandidateFix): + """A candidate that passed safety and SMT verification.""" + + verifier_reason: str = Field(min_length=1) + + +class RepairFailure(BaseModel): + """Machine-readable account of an issue that could not be repaired.""" + + row: int = Field(ge=0) + column: str = Field(min_length=1) + issue_type: str = Field(min_length=1) + status: str = Field(min_length=1) + reason: str = Field(min_length=1) + attempt_count: int = Field(ge=1) + unsat_core: tuple[str, ...] = Field(default_factory=tuple) + + model_config = ConfigDict(strict=True, extra="forbid", frozen=True) + + @classmethod + def from_attempts(cls, attempts: list[RepairAttempt]) -> RepairFailure: + """Build a public failure record from one issue's attempt trace.""" + final = attempts[-1] + issue = final.issue + return cls( + row=issue.row, + column=issue.column, + issue_type=issue.issue_type, + status=final.status, + reason=final.reason, + attempt_count=len(attempts), + unsat_core=tuple(final.unsat_core), + ) + + +class RepairReceipt(BaseModel): + """Stable receipt for a dry-run or applied repair pipeline run.""" + + contract_version: str = CONTRACT_VERSION + mode: RepairMode + applied: bool + reversible: bool + source_path: str + source_sha256: str = Field(pattern=r"^[0-9a-f]{64}$") + post_sha256: str | None = Field(default=None, pattern=r"^[0-9a-f]{64}$") + txn_id: str | None = None + allowed_columns: list[str] = Field(default_factory=list) + valid_rows: list[int] = Field(default_factory=list) + issues_count: int = Field(ge=0) + fixes_count: int = Field(ge=0) + reason: str = Field(min_length=1) + + model_config = ConfigDict(strict=True, extra="forbid", frozen=True) + + +class RepairPipelineRequest(BaseModel): + """Input contract for running the public repair pipeline.""" + + source_path: Path + mode: RepairMode = "dry_run" + repair_schema: Schema | None = Field(default=None, alias="schema") + allow_llm: bool = False + model: str = "gemini-2.0-flash" + allow_pii: bool = False + confirm_pii: bool = False + confirm_escalations: bool = False + interactive: bool = False + create_dry_run_transaction: bool = False + + model_config = ConfigDict( + strict=True, + arbitrary_types_allowed=True, + extra="forbid", + frozen=True, + populate_by_name=True, + ) + + +class RepairPipelineResult(BaseModel): + """Output contract for a public repair pipeline run.""" + + receipt: RepairReceipt + issues: list[Issue] + fixes: list[VerifiedFix] + failures: list[RepairFailure] = Field(default_factory=list) + transaction: RepairTransaction | None = None + + model_config = ConfigDict( + strict=True, arbitrary_types_allowed=True, extra="forbid", frozen=True + ) + + +def _atomic_write_bytes(path: Path, payload: bytes) -> None: + """Write bytes to ``path`` through an atomic same-directory replacement.""" + resolved = path.resolve() + resolved.parent.mkdir(parents=True, exist_ok=True) + temp_path = resolved.with_name(f".{resolved.name}.{secrets.token_hex(8)}.tmp") + try: + with temp_path.open("xb") as handle: + handle.write(payload) + handle.flush() + os.fsync(handle.fileno()) + os.replace(temp_path, resolved) + finally: + if temp_path.exists(): + temp_path.unlink() + + +def read_csv(path: Path) -> Table: + """Read a CSV using conservative string-preserving defaults.""" + return read_table_csv(path) + + +def _csv_bytes_after_fixes(path: Path, fixes: list[CellFix]) -> bytes: + """Validate fixes against a CSV and return the mutated CSV bytes.""" + df = read_csv(path) + for fix in fixes: + if fix.operation != "update": + raise ValueError(f"Unsupported repair operation '{fix.operation}' for row {fix.row}.") + if fix.column not in column_names(df): + raise ValueError(f"Column '{fix.column}' not found in '{path}'.") + if fix.row < 0 or fix.row >= row_count(df): + raise ValueError(f"Row {fix.row} is out of bounds for '{path}'.") + + current_value = cell_value(df, fix.row, fix.column) + if current_value != fix.old_value: + raise ValueError( + f"Refusing to apply stale fix for row {fix.row}, column '{fix.column}': " + f"expected '{fix.old_value}', found '{current_value}'." + ) + set_cell_value(df, fix.row, fix.column, fix.new_value) + + return table_to_csv_bytes(df) + + +def apply_fixes_to_csv(path: Path, fixes: list[CellFix]) -> str: + """Atomically apply ordered cell fixes to a CSV and return post-state SHA-256.""" + payload = _csv_bytes_after_fixes(path, fixes) + _atomic_write_bytes(path, payload) + return hashlib.sha256(payload).hexdigest() + + +def _lock_path_for(source_path: Path) -> Path: + """Return the filesystem lock path for a source file.""" + digest = hashlib.sha256(str(source_path.resolve()).encode("utf-8")).hexdigest()[:24] + return source_path.resolve().parent / ".dataforge" / "locks" / f"{digest}.lock" + + +@contextmanager +def source_path_lock( + source_path: Path, + *, + timeout_seconds: float = 5.0, + stale_after_seconds: float = 300.0, +) -> Iterator[None]: + """Acquire an exclusive lock for a source path using an atomic lock file.""" + lock_path = _lock_path_for(source_path) + lock_path.parent.mkdir(parents=True, exist_ok=True) + deadline = time.monotonic() + timeout_seconds + while True: + try: + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + try: + payload = f"{os.getpid()} {datetime.now(UTC).isoformat()}\n".encode() + os.write(fd, payload) + finally: + os.close(fd) + break + except FileExistsError as exc: + try: + age = time.time() - lock_path.stat().st_mtime + except OSError: + age = 0.0 + if age > stale_after_seconds: + try: + lock_path.unlink() + continue + except OSError: + pass + if time.monotonic() >= deadline: + raise TransactionApplyError( + f"Timed out waiting for DataForge source lock: {source_path.resolve()}" + ) from exc + time.sleep(0.05) + + try: + yield + finally: + with suppress(FileNotFoundError): + lock_path.unlink() + + +def _write_snapshot_once(snapshot_path: Path, source_bytes: bytes) -> None: + """Write an immutable snapshot and fail if the transaction id already exists.""" + snapshot_path.parent.mkdir(parents=True, exist_ok=True) + try: + with snapshot_path.open("xb") as handle: + handle.write(source_bytes) + handle.flush() + os.fsync(handle.fileno()) + except FileExistsError as exc: + raise TransactionApplyError( + f"Transaction snapshot already exists: {snapshot_path}" + ) from exc + + +def create_repair_transaction( + path: Path, + fixes: list[ProposedFix], + source_bytes: bytes, + *, + txn_id: str | None = None, +) -> tuple[RepairTransaction, Path]: + """Create an unapplied transaction journal and immutable source snapshot.""" + resolved_path = path.resolve() + transaction_id = txn_id or generate_txn_id() + snapshot_path = snapshot_path_for(resolved_path, transaction_id) + _write_snapshot_once(snapshot_path, source_bytes) + + transaction = RepairTransaction( + txn_id=transaction_id, + created_at=datetime.now(UTC), + source_path=str(resolved_path), + source_sha256=sha256_bytes(source_bytes), + source_snapshot_path=str(snapshot_path.resolve()), + fixes=[proposal.fix for proposal in fixes], + applied=False, + ) + try: + log_path = append_created_transaction(transaction) + except Exception: + snapshot_path.unlink(missing_ok=True) + raise + return transaction, log_path + + +def apply_transaction( + path: Path, + fixes: list[ProposedFix], + source_bytes: bytes, + *, + txn_id: str | None = None, +) -> str: + """Journal, snapshot, atomically apply fixes, and restore bytes on failure.""" + resolved_path = path.resolve() + with source_path_lock(resolved_path): + current_bytes = resolved_path.read_bytes() + if current_bytes != source_bytes: + raise TransactionApplyError( + "Refusing to apply repairs because the source file changed after detection." + ) + + with repair_stage_span("dataforge.repair.transaction.create", fixes_count=len(fixes)): + transaction, log_path = create_repair_transaction( + resolved_path, + fixes, + source_bytes, + txn_id=txn_id, + ) + try: + with repair_stage_span("dataforge.repair.transaction.apply", fixes_count=len(fixes)): + post_sha256 = apply_fixes_to_csv( + resolved_path, + [proposal.fix for proposal in fixes], + ) + append_applied_event(log_path, transaction.txn_id, post_sha256=post_sha256) + except Exception as exc: + _atomic_write_bytes(resolved_path, source_bytes) + if sha256_file(resolved_path) != transaction.source_sha256: + raise TransactionApplyError( + "Apply failed and the source file could not be restored to original bytes." + ) from exc + raise + + return transaction.txn_id + + +def _build_retry_context(issue: Issue, attempts: list[RepairAttempt]) -> RetryContext: + """Build retry hints from previous failed attempts.""" + rejected_values = frozenset( + attempt.fix.fix.new_value + for attempt in attempts + if attempt.fix is not None and attempt.status in {"denied", "rejected", "unknown"} + ) + hints: list[str] = [] + for attempt in attempts: + hints.append(attempt.reason) + hints.extend(attempt.unsat_core) + return RetryContext( + issue=issue, + previous_attempts=tuple(attempts), + rejected_values=rejected_values, + hints=tuple(hints), + ) + + +def propose_repairs( + issues: list[Issue], + path: Path, + working_df: TableLike, + schema: Schema | None, + *, + allow_llm: bool, + model: str, + allow_pii: bool, + confirm_pii: bool, + confirm_escalations: bool, + interactive: bool, + escalation_resolver: EscalationResolver | None = None, +) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]: + """Run repairers and gates issue-by-issue against a working dataframe.""" + with repair_stage_span("dataforge.repair.repairers.build", allow_llm=allow_llm): + repairers = build_repairers( + cache_dir=cache_dir_for(path), + allow_llm=allow_llm, + model=model, + ) + safety_filter = SafetyFilter() + verifier = SMTVerifier() + safety_context = SafetyContext( + allow_pii=allow_pii, + confirm_pii=confirm_pii, + confirm_escalations=confirm_escalations, + ) + + accepted_fixes: list[ProposedFix] = [] + attempt_groups: list[list[RepairAttempt]] = [] + + for issue in issues: + attempts: list[RepairAttempt] = [] + repairer = repairers.get(issue.issue_type) + if repairer is None: + attempts.append( + RepairAttempt( + issue=issue, + attempt_number=1, + status="attempted_not_fixed", + reason="No repairer is registered for this issue type.", + ) + ) + attempt_groups.append(attempts) + continue + + accepted = False + retry_context = RetryContext(issue=issue) + for attempt_number in range(1, 4): + candidate = repairer.propose(issue, working_df, schema, retry_context=retry_context) + if candidate is None: + attempts.append( + RepairAttempt( + issue=issue, + attempt_number=attempt_number, + status="attempted_not_fixed", + reason="No repair proposal was available for this issue.", + ) + ) + break + + preferred = safety_filter.choose_preferred([candidate], schema, safety_context) + safety_result = safety_filter.evaluate(preferred, schema, safety_context) + if ( + safety_result.verdict == SafetyVerdict.ESCALATE + and interactive + and escalation_resolver is not None + ): + safety_context, safety_result = escalation_resolver( + preferred, + schema, + safety_context, + safety_filter, + safety_result, + ) + + if safety_result.verdict == SafetyVerdict.DENY: + attempts.append( + RepairAttempt( + issue=issue, + attempt_number=attempt_number, + fix=preferred, + status="denied", + reason=safety_result.reason, + ) + ) + retry_context = _build_retry_context(issue, attempts) + continue + + if safety_result.verdict == SafetyVerdict.ESCALATE: + attempts.append( + RepairAttempt( + issue=issue, + attempt_number=attempt_number, + fix=preferred, + status="escalated", + reason=safety_result.reason, + ) + ) + break + + with repair_stage_span( + "dataforge.repair.verifier.verify", + issue_type=issue.issue_type, + row=issue.row, + ): + verifier_result = verifier.verify(working_df, [preferred], schema) + if verifier_result.verdict == VerificationVerdict.ACCEPT: + accepted_fixes.append(preferred) + set_cell_value( + working_df, + preferred.fix.row, + preferred.fix.column, + preferred.fix.new_value, + ) + attempts.append( + RepairAttempt( + issue=issue, + attempt_number=attempt_number, + fix=preferred, + status="accepted", + reason=verifier_result.reason, + ) + ) + accepted = True + break + + attempts.append( + RepairAttempt( + issue=issue, + attempt_number=attempt_number, + fix=preferred, + status=( + "rejected" + if verifier_result.verdict == VerificationVerdict.REJECT + else "unknown" + ), + reason=verifier_result.reason, + unsat_core=verifier_result.unsat_core, + ) + ) + retry_context = _build_retry_context(issue, attempts) + + if ( + not accepted + and attempts + and attempts[-1].status not in {"attempted_not_fixed", "escalated"} + ): + last_reason = attempts[-1].reason + attempts[-1] = attempts[-1].model_copy( + update={ + "status": "attempted_not_fixed", + "reason": ( + f"Issue was attempted but not fixed after {len(attempts)} attempt(s). " + f"Last failure: {last_reason}" + ), + } + ) + attempt_groups.append(attempts) + + return accepted_fixes, attempt_groups + + +def _verified_fixes( + fixes: list[ProposedFix], + attempt_groups: list[list[RepairAttempt]], +) -> list[VerifiedFix]: + """Build public verified fix payloads using accepted attempt reasons.""" + accepted_reasons: dict[tuple[int, str, str], str] = {} + for attempts in attempt_groups: + for attempt in attempts: + if attempt.status == "accepted" and attempt.fix is not None: + fix = attempt.fix.fix + accepted_reasons[(fix.row, fix.column, fix.new_value)] = attempt.reason + + return [ + VerifiedFix( + **CandidateFix.from_proposed(fix).model_dump(), + verifier_reason=accepted_reasons.get( + (fix.fix.row, fix.fix.column, fix.fix.new_value), + "Accepted by verifier.", + ), + ) + for fix in fixes + ] + + +def _failed_attempts(attempt_groups: list[list[RepairAttempt]]) -> list[RepairFailure]: + """Return failures for issue groups whose final status was not accepted.""" + return [ + RepairFailure.from_attempts(attempts) + for attempts in attempt_groups + if attempts and attempts[-1].status != "accepted" + ] + + +def run_repair_pipeline(request: RepairPipelineRequest) -> RepairPipelineResult: + """Run the public repair pipeline from detection through optional apply.""" + source_path = request.source_path.resolve() + source_bytes = source_path.read_bytes() + df = read_csv(source_path) + with repair_stage_span("dataforge.repair.detect", row_count=row_count(df)): + issues = run_all_detectors(df, request.repair_schema) + with repair_stage_span("dataforge.repair.propose", issue_count=len(issues)): + accepted_fixes, attempt_groups = propose_repairs( + issues, + source_path, + copy_table(df), + request.repair_schema, + allow_llm=request.allow_llm, + model=request.model, + allow_pii=request.allow_pii, + confirm_pii=request.confirm_pii, + confirm_escalations=request.confirm_escalations, + interactive=request.interactive, + ) + + with repair_stage_span("dataforge.repair.safety.batch", fixes_count=len(accepted_fixes)): + batch_safety = SafetyFilter().evaluate_batch(accepted_fixes) + failures = _failed_attempts(attempt_groups) + transaction: RepairTransaction | None = None + txn_id: str | None = None + post_sha256: str | None = None + applied = False + reason = "No accepted fixes were produced." + + if batch_safety.verdict != SafetyVerdict.ALLOW: + accepted_fixes = [] + reason = batch_safety.reason + elif request.mode == "apply" and accepted_fixes: + txn_id = apply_transaction(source_path, accepted_fixes, source_bytes) + post_sha256 = sha256_file(source_path) + applied = True + reason = f"Applied {len(accepted_fixes)} fix(es)." + elif request.create_dry_run_transaction: + transaction, _log_path = create_repair_transaction( + source_path, accepted_fixes, source_bytes + ) + txn_id = transaction.txn_id + reason = ( + "Dry run completed without mutating the source file." + if accepted_fixes + else "No accepted fixes were produced." + ) + elif accepted_fixes: + reason = "Dry run completed without mutating the source file." + + if txn_id is not None and transaction is None: + # Replaying the log is unnecessary for the public contract here; this + # minimal receipt is intentionally enough for API callers. + transaction = None + + receipt = RepairReceipt( + mode=request.mode, + applied=applied, + reversible=True, + source_path=str(source_path), + source_sha256=sha256_bytes(source_bytes), + post_sha256=post_sha256, + txn_id=txn_id, + allowed_columns=column_names(df), + valid_rows=list(range(row_count(df))), + issues_count=len(issues), + fixes_count=len(accepted_fixes), + reason=reason, + ) + return RepairPipelineResult( + receipt=receipt, + issues=issues, + fixes=_verified_fixes(accepted_fixes, attempt_groups), + failures=failures, + transaction=transaction, + ) diff --git a/dataforge/env/__init__.py b/dataforge/env/__init__.py index f3dde09c1736f14edf88dbfb8e943bbc169c2d30..c5faa3c7ae8d16193c24d61ca20ff4966dc90182 100644 --- a/dataforge/env/__init__.py +++ b/dataforge/env/__init__.py @@ -1 +1,22 @@ -"""Environment package scaffolding for DataForge.""" +"""DataForge RL environment — OpenEnv-compatible data-quality environment. + +Public API: + DataForgeEnv — Core environment with reset/step/state/close. + ResetResult — Return type of reset(). + StepResult — Return type of step(). + EnvState — State snapshot from state(). + DataForgeObservation — Agent-visible observation. + ToolResult — Structured result from each action. +""" + +from dataforge.env.environment import DataForgeEnv, EnvState, ResetResult, StepResult +from dataforge.env.observation import DataForgeObservation, ToolResult + +__all__ = [ + "DataForgeEnv", + "DataForgeObservation", + "EnvState", + "ResetResult", + "StepResult", + "ToolResult", +] diff --git a/dataforge/env/environment.py b/dataforge/env/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..5edca5715073659649f0b0a817f150b191a08d99 --- /dev/null +++ b/dataforge/env/environment.py @@ -0,0 +1,884 @@ +"""OpenEnv-compatible DataForge RL environment. + +Core environment implementing reset/step/state/close for data-quality +detection, diagnosis, and repair with typed tool-use actions. + +No LLM calls. No disk writes. Dataset state is in-memory per episode. +""" + +from __future__ import annotations + +import logging +import random +import re +import uuid +from difflib import SequenceMatcher +from pathlib import Path +from typing import Any, cast + +import duckdb +import pandas as pd +import sqlglot +import sqlglot.expressions as sqlglot_exp +from pydantic import BaseModel, Field + +from dataforge.agent.scratchpad import Scratchpad +from dataforge.agent.tool_actions import ( + Action, + Diagnose, + Fix, + Hypothesis, + InspectRows, + PatternMatch, + RootCause, + SqlQuery, + StatTest, + parse_action, +) +from dataforge.detectors import run_all_detectors +from dataforge.detectors.base import Issue +from dataforge.env.observation import DataForgeObservation, ToolResult +from dataforge.env.reward import ( + P_FALSE_POS, + P_INVALID, + P_WRONG_FIX, + R_EXPLORE, + R_ROOT_CAUSE, + EpisodeMetrics, + RewardEngine, +) + +logger = logging.getLogger("dataforge.env") + +__all__ = [ + "DataForgeEnv", + "EnvState", + "ResetResult", + "StepResult", +] + +_FIXTURES_DIR = Path(__file__).resolve().parents[1].parent / "fixtures" +_DEFAULT_CSV = _FIXTURES_DIR / "hospital_10rows.csv" +_DEFAULT_SCHEMA = _FIXTURES_DIR / "hospital_schema.yaml" +_MAX_STEPS = 30 +_MAX_RESULT_ROWS = 20 +_TOOL_HISTORY_LIMIT = 5 +_NOISE_EPSILON = 0.15 +_BLOCKED_SQL_FRAGMENTS = ( + "attach", + "call ", + "copy ", + "detach", + "duckdb_extensions", + "filename", + "from_csv_auto", + "glob(", + "http://", + "https://", + "httpfs", + "install", + "load ", + "mysql_scan", + "parquet_scan", + "postgres_scan", + "pragma", + "read_csv", + "read_json", + "read_parquet", + "s3://", + "sqlite_scan", +) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Result models +# ═══════════════════════════════════════════════════════════════════════════ + + +class ResetResult(BaseModel): + """Result of env.reset().""" + + observation: DataForgeObservation + info: dict[str, Any] = Field(default_factory=dict) + + +class StepResult(BaseModel): + """Result of env.step().""" + + observation: DataForgeObservation + reward: float = 0.0 + done: bool = False + info: dict[str, Any] = Field(default_factory=dict) + + +class EnvState(BaseModel): + """Internal environment state snapshot.""" + + episode_id: str = "" + step_count: int = 0 + task_id: str = "" + issues_detected: int = 0 + issues_fixed: int = 0 + false_positives: int = 0 + total_issues: int = 0 + is_done: bool = False + + +# ═══════════════════════════════════════════════════════════════════════════ +# Environment +# ═══════════════════════════════════════════════════════════════════════════ + + +class DataForgeEnv: + """OpenEnv-compatible RL environment for data quality repair. + + Core API: ``reset()``, ``step()``, ``state()``, ``close()`` (no-op). + + Example:: + + >>> env = DataForgeEnv() + >>> result = env.reset(seed=42) + >>> result.observation.done + False + """ + + def __init__(self, max_steps: int = _MAX_STEPS) -> None: + self._max_steps = max_steps + self._episode_id = "" + self._step_count = 0 + self._df: pd.DataFrame = pd.DataFrame() + self._ground_truth: list[Issue] = [] + self._found_issues: list[dict[str, Any]] = [] + self._fixed_issues: list[dict[str, Any]] = [] + self._false_positives = 0 + self._cumulative_reward = 0.0 + self._is_done = False + self._inspected_rows: set[int] = set() + self._noisy = False + self._noise_rng: random.Random | None = None + self._scratchpad = Scratchpad() + self._tool_history: list[ToolResult] = [] + self._reward_engine = RewardEngine() + self._schema_info: dict[str, str] = {} + self._causal_dag_cache: Any = None + self._root_cause_labels: set[int] = set() + + # ── Core API ────────────────────────────────────────────────────────── + + def reset(self, seed: int | None = None, *, noisy: bool = False) -> ResetResult: + """Reset the environment for a new episode. + + Args: + seed: Optional RNG seed for deterministic episodes. + noisy: If True, enable observation noise (epsilon=0.15). + + Returns: + ResetResult with initial observation. + """ + self._episode_id = str(uuid.uuid4()) + self._step_count = 0 + self._found_issues = [] + self._fixed_issues = [] + self._false_positives = 0 + self._cumulative_reward = 0.0 + self._is_done = False + self._inspected_rows = set() + self._scratchpad.reset() + self._tool_history = [] + self._causal_dag_cache = None + self._root_cause_labels = set() + self._noisy = noisy + self._noise_rng = random.Random(seed if seed is not None else 0) if noisy else None + + # Load fixture dataset + self._df = pd.read_csv(_DEFAULT_CSV, dtype=str) + self._schema_info = dict.fromkeys(self._df.columns, "str") + if _DEFAULT_SCHEMA.exists(): + import yaml + + with open(_DEFAULT_SCHEMA, encoding="utf-8") as f: + schema_data = yaml.safe_load(f) + if schema_data and "columns" in schema_data: + self._schema_info = schema_data["columns"] + + # Run detectors for hidden ground truth + self._ground_truth = run_all_detectors(self._df) + logger.info( + "Episode %s: %d rows, %d ground-truth issues", + self._episode_id[:8], + len(self._df), + len(self._ground_truth), + ) + + # Initial observation with first 5 rows + initial_rows = cast(list[dict[str, Any]], self._df.head(5).to_dict(orient="records")) + obs = DataForgeObservation( + visible_rows=initial_rows, + step_budget_remaining=self._max_steps, + scratchpad_summary=self._scratchpad.summary(), + metadata={ + "episode_id": self._episode_id, + "total_rows": len(self._df), + "total_columns": len(self._df.columns), + "schema": self._schema_info, + }, + ) + return ResetResult(observation=obs, info={"episode_id": self._episode_id}) + + def step(self, action: Action | dict[str, Any]) -> StepResult: + """Execute one agent action and return the result. + + Args: + action: A typed Action model or raw dict to be parsed. + + Returns: + StepResult with observation, reward, and done flag. + """ + if self._is_done: + return self._terminal_result(0.0) + + self._step_count += 1 + + # Parse if raw dict + if isinstance(action, dict): + try: + action = parse_action(action) + except Exception as exc: + return self._error_step(str(exc)) + + # Dispatch + try: + tool_result, reward = self._dispatch(action) + except Exception as exc: + logger.exception("Action dispatch error at step %d", self._step_count) + return self._error_step(str(exc)) + + # Late-step penalty + reward += self._reward_engine.compute_late_penalty(self._step_count, self._max_steps) + + # Accumulate + self._cumulative_reward += reward + + # Record in history + self._tool_history.append(tool_result) + if len(self._tool_history) > _TOOL_HISTORY_LIMIT: + self._tool_history = self._tool_history[-_TOOL_HISTORY_LIMIT:] + + # Check termination + done = self._step_count >= self._max_steps + if done: + self._is_done = True + terminal = self._compute_terminal() + self._cumulative_reward = max(self._cumulative_reward, terminal) + + obs = DataForgeObservation( + visible_rows=tool_result.data + if tool_result.action_type == "INSPECT_ROWS" and tool_result.success + else None, + scratchpad_summary=self._scratchpad.summary(), + step_budget_remaining=max(0, self._max_steps - self._step_count), + tool_usage_history=list(self._tool_history), + latest_result=tool_result, + done=done, + reward=reward, + cumulative_reward=self._cumulative_reward, + ) + return StepResult(observation=obs, reward=reward, done=done) + + def state(self) -> EnvState: + """Return current internal state snapshot.""" + return EnvState( + episode_id=self._episode_id, + step_count=self._step_count, + issues_detected=len(self._found_issues), + issues_fixed=len(self._fixed_issues), + false_positives=self._false_positives, + total_issues=len(self._ground_truth), + is_done=self._is_done, + ) + + def close(self) -> None: + """No-op. Retained for OpenEnv container compatibility.""" + + # ── Dispatch ────────────────────────────────────────────────────────── + + def _dispatch(self, action: Action) -> tuple[ToolResult, float]: + """Route action to handler. Returns (tool_result, step_reward).""" + if isinstance(action, InspectRows): + return self._handle_inspect(action) + if isinstance(action, SqlQuery): + return self._handle_sql(action) + if isinstance(action, StatTest): + return self._handle_stat(action) + if isinstance(action, PatternMatch): + return self._handle_pattern(action) + if isinstance(action, Hypothesis): + return self._handle_hypothesis(action) + if isinstance(action, RootCause): + return self._handle_root_cause(action) + if isinstance(action, Diagnose): + return self._handle_diagnose(action) + if isinstance(action, Fix): + return self._handle_fix(action) + return ToolResult( + action_type="UNKNOWN", + success=False, + error={"verdict": "error", "reason": "Unknown action type"}, + ), P_INVALID + + # ── Action handlers ─────────────────────────────────────────────────── + + def _handle_inspect(self, action: InspectRows) -> tuple[ToolResult, float]: + """Handle INSPECT_ROWS: return dataset rows.""" + valid_indices = [i for i in action.row_indices if 0 <= i < len(self._df)] + if not valid_indices: + return ToolResult( + action_type="INSPECT_ROWS", + success=False, + error={"verdict": "error", "reason": "No valid row indices"}, + ), P_INVALID + + # Apply 20-row cap + valid_indices = valid_indices[:20] + rows = self._df.iloc[valid_indices] + if action.column_names: + valid_cols = [c for c in action.column_names if c in self._df.columns] + if valid_cols: + rows = rows[valid_cols] + + row_dicts = cast(list[dict[str, Any]], rows.to_dict(orient="records")) + for i, idx in enumerate(valid_indices[: len(row_dicts)]): + row_dicts[i]["_row_index"] = idx + + # Noise injection + if self._noisy and self._noise_rng: + row_dicts = self._inject_noise(row_dicts) + + # Exploration bonus + new_indices = set(valid_indices) - self._inspected_rows + self._inspected_rows.update(valid_indices) + gt_rows = {issue.row for issue in self._ground_truth} + found_rows = {f["row"] for f in self._found_issues} + bonus = self._reward_engine.compute_exploration_bonus( + new_indices, + self._inspected_rows, + len(self._df), + gt_rows, + found_rows, + ) + return ToolResult(action_type="INSPECT_ROWS", success=True, data=row_dicts), bonus + + def _handle_sql(self, action: SqlQuery) -> tuple[ToolResult, float]: + """Handle SQL_QUERY: execute read-only SQL via DuckDB.""" + # Validate read-only + try: + parsed = [stmt for stmt in sqlglot.parse(action.query) if stmt is not None] + except sqlglot.errors.ParseError as exc: + return ToolResult( + action_type="SQL_QUERY", + success=False, + error={ + "verdict": "error", + "reason": str(exc), + "suggested_constraint": "Use valid SQL syntax", + }, + ), P_INVALID + + if len(parsed) != 1: + return ToolResult( + action_type="SQL_QUERY", + success=False, + error={ + "verdict": "rejected", + "reason": "Exactly one SELECT statement is allowed.", + "suggested_constraint": "Use a single read-only SELECT statement.", + }, + ), P_INVALID + + normalized_query = f" {action.query.lower()} " + blocked = next( + (fragment for fragment in _BLOCKED_SQL_FRAGMENTS if fragment in normalized_query), + None, + ) + if blocked is not None: + return ToolResult( + action_type="SQL_QUERY", + success=False, + error={ + "verdict": "rejected", + "reason": "SQL_QUERY may only read from the registered data relation.", + "suggested_constraint": "Query the in-memory data table without file, network, extension, or table functions.", + }, + ), P_INVALID + + for stmt in parsed: + if stmt.key not in ("select",): + return ToolResult( + action_type="SQL_QUERY", + success=False, + error={ + "verdict": "rejected", + "reason": f"Only SELECT queries allowed, got {stmt.key.upper()}", + "suggested_constraint": "Use SELECT statements only", + }, + ), P_INVALID + + for table in stmt.find_all(sqlglot_exp.Table): + if table.name.lower() != "data": + return ToolResult( + action_type="SQL_QUERY", + success=False, + error={ + "verdict": "rejected", + "reason": ( + "SQL_QUERY may only reference the registered data relation; " + f"got '{table.name}'." + ), + "suggested_constraint": "Use FROM data for tabular queries.", + }, + ), P_INVALID + + try: + conn = duckdb.connect(":memory:") + conn.register("data", self._df) + result_df = conn.execute(action.query).fetchdf() + conn.close() + rows = result_df.head(_MAX_RESULT_ROWS).to_dict(orient="records") + return ToolResult(action_type="SQL_QUERY", success=True, data=rows), 0.0 + except duckdb.Error as exc: + return ToolResult( + action_type="SQL_QUERY", + success=False, + error={"verdict": "error", "reason": str(exc)}, + ), P_INVALID + + def _handle_stat(self, action: StatTest) -> tuple[ToolResult, float]: + """Handle STAT_TEST: run zscore/iqr/ks on a column.""" + if action.column not in self._df.columns: + return ToolResult( + action_type="STAT_TEST", + success=False, + error={"verdict": "error", "reason": f"Column '{action.column}' not found"}, + ), P_INVALID + + try: + col = pd.to_numeric(self._df[action.column], errors="coerce").dropna() + if len(col) == 0: + return ToolResult( + action_type="STAT_TEST", + success=False, + error={ + "verdict": "error", + "reason": f"No numeric values in column '{action.column}'", + }, + ), P_INVALID + except Exception as exc: + return ToolResult( + action_type="STAT_TEST", + success=False, + error={"verdict": "error", "reason": str(exc)}, + ), P_INVALID + + from scipy import stats as scipy_stats # type: ignore[import-untyped] + + if action.test_type == "zscore": + zscores = scipy_stats.zscore(col) + threshold = action.threshold or 3.0 + outliers = col.index[abs(zscores) > threshold].tolist() + data = { + "test": "zscore", + "threshold": threshold, + "outlier_indices": outliers, + "n_outliers": len(outliers), + "mean": float(col.mean()), + "std": float(col.std()), + } + elif action.test_type == "iqr": + q1, q3 = float(col.quantile(0.25)), float(col.quantile(0.75)) + iqr_val = q3 - q1 + factor = action.threshold or 1.5 + lower, upper = q1 - factor * iqr_val, q3 + factor * iqr_val + outliers = col.index[(col < lower) | (col > upper)].tolist() + data = { + "test": "iqr", + "q1": q1, + "q3": q3, + "iqr": iqr_val, + "lower": lower, + "upper": upper, + "outlier_indices": outliers, + } + elif action.test_type == "ks": + stat_val, p_val = scipy_stats.kstest( + col, "norm", args=(float(col.mean()), float(col.std())) + ) + data = { + "test": "ks", + "statistic": float(stat_val), + "p_value": float(p_val), + "normal": p_val > 0.05, + } + else: + return ToolResult( + action_type="STAT_TEST", + success=False, + error={"verdict": "error", "reason": f"Unknown test type: {action.test_type}"}, + ), P_INVALID + + return ToolResult(action_type="STAT_TEST", success=True, data=data), 0.0 + + def _handle_pattern(self, action: PatternMatch) -> tuple[ToolResult, float]: + """Handle PATTERN_MATCH: evaluate regex against column values.""" + if action.column not in self._df.columns: + return ToolResult( + action_type="PATTERN_MATCH", + success=False, + error={"verdict": "error", "reason": f"Column '{action.column}' not found"}, + ), P_INVALID + + try: + compiled = re.compile(action.pattern) + except re.error as exc: + return ToolResult( + action_type="PATTERN_MATCH", + success=False, + error={"verdict": "error", "reason": f"Invalid regex: {exc}"}, + ), P_INVALID + + matches: list[dict[str, Any]] = [] + for idx, val in enumerate(self._df[action.column].astype(str)): + is_match = bool(compiled.search(val)) + if is_match == action.expect_match: + matches.append({"row": idx, "column": action.column, "value": val}) + return ToolResult( + action_type="PATTERN_MATCH", + success=True, + data={"matches": matches[:_MAX_RESULT_ROWS], "total_matches": len(matches)}, + ), 0.0 + + def _handle_hypothesis(self, action: Hypothesis) -> tuple[ToolResult, float]: + """Handle HYPOTHESIS: record claim and award root-cause credit.""" + self._scratchpad.add_hypothesis( + action.claim, + action.affected_rows, + action.affected_columns, + action.root_cause_type, + ) + # Check for root-cause match against ground truth + credit = 0.0 + for issue in self._ground_truth: + if ( + issue.row in action.affected_rows + and issue.column in action.affected_columns + and issue.issue_type == action.root_cause_type + ): + credit += R_EXPLORE + data = {"recorded": True, "root_cause_credit": credit} + return ToolResult(action_type="HYPOTHESIS", success=True, data=data), credit + + def _handle_root_cause(self, action: RootCause) -> tuple[ToolResult, float]: + """Handle ROOT_CAUSE: analyze detected issues for minimal roots.""" + if not self._found_issues: + return ToolResult( + action_type="ROOT_CAUSE", + success=False, + error={"verdict": "error", "reason": "No detected issues are available"}, + ), P_INVALID + + invalid = [idx for idx in action.error_indices if idx >= len(self._found_issues)] + if invalid: + return ToolResult( + action_type="ROOT_CAUSE", + success=False, + error={ + "verdict": "error", + "reason": f"Detected issue indices out of range: {invalid}", + }, + ), P_INVALID + + from dataforge.causal.pc import discover_causal_dag + from dataforge.causal.root_cause import CausalRootCauseAnalyzer, evidence_from_issue + + if self._causal_dag_cache is None: + self._causal_dag_cache = discover_causal_dag(self._df).dag + + selected = [ + evidence_from_issue(index, self._found_issues[index]) for index in action.error_indices + ] + result = CausalRootCauseAnalyzer(self._causal_dag_cache).analyze(selected) + data = result.model_dump(mode="json") + reward = self._root_cause_reward(set(result.root_indices)) + return ToolResult(action_type="ROOT_CAUSE", success=True, data=data), reward + + def _handle_diagnose(self, action: Diagnose) -> tuple[ToolResult, float]: + """Handle DIAGNOSE: score against ground truth.""" + if action.row < 0 or action.row >= len(self._df): + return ToolResult( + action_type="DIAGNOSE", + success=False, + error={"verdict": "error", "reason": f"Row {action.row} out of bounds"}, + ), P_INVALID + if action.column not in self._df.columns: + return ToolResult( + action_type="DIAGNOSE", + success=False, + error={"verdict": "error", "reason": f"Column '{action.column}' not found"}, + ), P_INVALID + + # Already reported? + for found in self._found_issues: + if found["row"] == action.row and found["column"] == action.column: + return ToolResult( + action_type="DIAGNOSE", success=True, data={"result": "already_found"} + ), 0.0 + + # Match ground truth + for issue in self._ground_truth: + if issue.row == action.row and issue.column == action.column: + type_match = action.issue_type == issue.issue_type + reward = self._reward_engine.diagnose_reward(type_match) + self._found_issues.append( + {"row": action.row, "column": action.column, "type": action.issue_type} + ) + self._scratchpad.confirm_issue(action.row, action.column, action.issue_type) + return ToolResult( + action_type="DIAGNOSE", + success=True, + data={"result": "correct", "type_match": type_match}, + ), reward + + # False positive + self._false_positives += 1 + return ToolResult( + action_type="DIAGNOSE", success=True, data={"result": "false_positive"} + ), P_FALSE_POS + + def _root_cause_reward(self, root_indices: set[int]) -> float: + """Return root-cause bonus only when task labels are available.""" + if not self._root_cause_labels: + return 0.0 + return R_ROOT_CAUSE if root_indices == self._root_cause_labels else 0.0 + + def _handle_fix(self, action: Fix) -> tuple[ToolResult, float]: + """Handle FIX: validate through safety/SMT, then score.""" + if action.row < 0 or action.row >= len(self._df): + return ToolResult( + action_type="FIX", + success=False, + error={"verdict": "error", "reason": f"Row {action.row} out of bounds"}, + ), P_INVALID + if action.column not in self._df.columns: + return ToolResult( + action_type="FIX", + success=False, + error={"verdict": "error", "reason": f"Column '{action.column}' not found"}, + ), P_INVALID + + # Already fixed? + for fixed in self._fixed_issues: + if fixed["row"] == action.row and fixed["column"] == action.column: + return ToolResult( + action_type="FIX", success=True, data={"result": "already_fixed"} + ), 0.0 + + # Safety filter + SMT verifier (best-effort, no crash on import failure) + try: + safety_ok, safety_msg = self._check_safety(action) + except Exception as exc: + logger.warning("Safety pipeline failed closed: %s", exc) + safety_ok = False + safety_msg = f"Safety pipeline failed closed: {exc}" + if not safety_ok: + return ToolResult( + action_type="FIX", + success=False, + error={"verdict": "rejected", "reason": safety_msg}, + ), P_INVALID + + # Match ground truth + for issue in self._ground_truth: + if issue.row == action.row and issue.column == action.column: + if issue.expected is None: + return ToolResult( + action_type="FIX", success=True, data={"result": "detection_only"} + ), 0.0 + + # Exact match (case-insensitive) + if action.new_value.strip().lower() == str(issue.expected).lower(): + reward = self._reward_engine.fix_reward( + exact=True, has_justification=bool(action.justification) + ) + self._fixed_issues.append( + {"row": action.row, "column": action.column, "value": action.new_value} + ) + self._auto_diagnose(action, issue) + return ToolResult( + action_type="FIX", success=True, data={"result": "correct"} + ), reward + + # Partial: numeric within 1% + try: + prov = float(action.new_value.strip()) + exp = float(str(issue.expected)) + rel_err = abs(prov - exp) / abs(exp) if exp != 0 else abs(prov) + if rel_err < 0.01: + reward = self._reward_engine.fix_reward( + exact=False, has_justification=bool(action.justification) + ) + self._fixed_issues.append( + {"row": action.row, "column": action.column, "value": action.new_value} + ) + self._auto_diagnose(action, issue) + return ToolResult( + action_type="FIX", success=True, data={"result": "partial_numeric"} + ), reward + except (ValueError, TypeError): + pass + + # Partial: string similarity >= 85% + sim = SequenceMatcher( + None, action.new_value.lower(), str(issue.expected).lower() + ).ratio() + if sim >= 0.85: + reward = self._reward_engine.fix_reward( + exact=False, has_justification=bool(action.justification) + ) + self._fixed_issues.append( + {"row": action.row, "column": action.column, "value": action.new_value} + ) + self._auto_diagnose(action, issue) + return ToolResult( + action_type="FIX", success=True, data={"result": "partial_string"} + ), reward + + return ToolResult( + action_type="FIX", success=True, data={"result": "wrong_value"} + ), P_WRONG_FIX + + return ToolResult( + action_type="FIX", success=True, data={"result": "no_issue_at_location"} + ), P_WRONG_FIX + + # ── Helpers ──────────────────────────────────────────────────────────── + + def _check_safety(self, action: Fix) -> tuple[bool, str]: + """Run SafetyFilter + SMTVerifier. Returns (ok, message).""" + try: + from dataforge.repairers.base import ProposedFix + from dataforge.safety.filter import SafetyContext, SafetyFilter, SafetyVerdict + from dataforge.transactions.txn import CellFix + from dataforge.verifier.smt import SMTVerifier, VerificationVerdict + + old_val = str(self._df.at[action.row, action.column]) + cell_fix = CellFix( + row=action.row, + column=action.column, + old_value=old_val, + new_value=action.new_value, + detector_id="agent", + ) + proposed = ProposedFix( + fix=cell_fix, + reason=action.justification, + confidence=0.8, + provenance="deterministic", + ) + + sf = SafetyFilter() + ctx = SafetyContext() + sr = sf.evaluate(proposed, None, ctx) + if sr.verdict == SafetyVerdict.DENY: + return False, f"Safety filter denied: {sr.reason}" + + verifier = SMTVerifier() + vr = verifier.verify(self._df, [proposed]) + if vr.verdict == VerificationVerdict.REJECT: + return False, f"SMT verifier rejected: {vr.reason}" + if vr.verdict == VerificationVerdict.UNKNOWN: + return False, f"SMT verifier returned unknown: {vr.reason}" + + return True, "Passed safety and verification" + except ImportError as exc: + return False, f"Safety/verifier dependency unavailable: {exc}" + + def _auto_diagnose(self, action: Fix, issue: Issue) -> None: + """Auto-credit diagnosis when agent fixes without diagnosing first.""" + already = any( + f["row"] == action.row and f["column"] == action.column for f in self._found_issues + ) + if not already: + self._found_issues.append( + {"row": action.row, "column": action.column, "type": issue.issue_type} + ) + + def _inject_noise(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Apply deterministic observation noise (epsilon=0.15).""" + if not self._noise_rng: + return rows + noisy = [] + for row in rows: + row_copy = dict(row) + if self._noise_rng.random() < _NOISE_EPSILON: + cols = [k for k in row_copy if k != "_row_index"] + if cols: + col = self._noise_rng.choice(cols) + val = row_copy[col] + if isinstance(val, str) and len(val) > 3: + row_copy[col] = ( + val[: -(self._noise_rng.randint(1, 3))] + if self._noise_rng.random() < 0.5 + else val.swapcase() + ) + noisy.append(row_copy) + return noisy + + def _compute_terminal(self) -> float: + """Compute terminal score.""" + fixable = [i for i in self._ground_truth if i.expected is not None] + metrics = EpisodeMetrics( + found_issues=len(self._found_issues), + total_issues=len(self._ground_truth), + fixed_issues=len(self._fixed_issues), + fixable_issues=len(fixable), + false_positives=self._false_positives, + ) + return self._reward_engine.compute_terminal_score(metrics) + + def _error_step(self, message: str) -> StepResult: + """Build error StepResult.""" + tr = ToolResult( + action_type="ERROR", success=False, error={"verdict": "error", "reason": message} + ) + self._tool_history.append(tr) + self._cumulative_reward += P_INVALID + done = self._step_count >= self._max_steps + if done: + self._is_done = True + return StepResult( + observation=DataForgeObservation( + step_budget_remaining=max(0, self._max_steps - self._step_count), + tool_usage_history=list(self._tool_history[-_TOOL_HISTORY_LIMIT:]), + latest_result=tr, + done=done, + reward=P_INVALID, + cumulative_reward=self._cumulative_reward, + scratchpad_summary=self._scratchpad.summary(), + ), + reward=P_INVALID, + done=done, + ) + + def _terminal_result(self, reward: float) -> StepResult: + """Build terminal StepResult for already-done episodes.""" + return StepResult( + observation=DataForgeObservation( + step_budget_remaining=0, + done=True, + reward=reward, + cumulative_reward=self._cumulative_reward, + scratchpad_summary=self._scratchpad.summary(), + tool_usage_history=list(self._tool_history[-_TOOL_HISTORY_LIMIT:]), + ), + reward=reward, + done=True, + ) diff --git a/dataforge/env/observation.py b/dataforge/env/observation.py new file mode 100644 index 0000000000000000000000000000000000000000..6165a1489f85e5e5059005ced3ed1c51ad8d1e77 --- /dev/null +++ b/dataforge/env/observation.py @@ -0,0 +1,61 @@ +"""Observation builder for the DataForge RL environment. + +Constructs agent-visible observations containing partial data views, +scratchpad summaries, tool results, and step budget information. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +__all__ = ["DataForgeObservation", "ToolResult"] + + +class ToolResult(BaseModel): + """Result of a single tool-use action. + + Args: + action_type: The action type that produced this result. + success: Whether the action succeeded. + data: Action-specific result data (rows, stats, matches, etc.). + error: Structured error information if the action failed. + """ + + action_type: str + success: bool = True + data: Any = None + error: dict[str, Any] | None = None + + model_config = {"frozen": True} + + +class DataForgeObservation(BaseModel): + """Agent-visible observation returned after each environment step. + + Args: + visible_rows: Dataset rows returned by INSPECT_ROWS or reset. + detector_hints: Optional hints from detectors (partial ground truth). + scratchpad_summary: Compact summary of the agent's scratchpad. + step_budget_remaining: Steps left before auto-finalize. + tool_usage_history: Last 5 tool results for context. + latest_result: Result of the most recent action. + done: Whether the episode has ended. + reward: Step reward. + cumulative_reward: Running total reward for the episode. + metadata: Additional key-value metadata. + """ + + visible_rows: list[dict[str, Any]] | None = None + detector_hints: list[str] | None = None + scratchpad_summary: str = "" + step_budget_remaining: int = 0 + tool_usage_history: list[ToolResult] = Field(default_factory=list) + latest_result: ToolResult | None = None + done: bool = False + reward: float = 0.0 + cumulative_reward: float = 0.0 + metadata: dict[str, Any] = Field(default_factory=dict) + + model_config = {"frozen": True} diff --git a/dataforge/env/openenv_core.py b/dataforge/env/openenv_core.py new file mode 100644 index 0000000000000000000000000000000000000000..7ee0e6e7366a7471576766ee46aa8f2568d1b40f --- /dev/null +++ b/dataforge/env/openenv_core.py @@ -0,0 +1,146 @@ +"""OpenEnv-core adapter for the DataForge RL environment.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import Field + +from dataforge.env.environment import DataForgeEnv + +try: + from openenv.core.env_server import ( + Action as OpenEnvAction, + ) + from openenv.core.env_server import ( + Environment as OpenEnvEnvironment, + ) + from openenv.core.env_server import ( + Observation as OpenEnvObservation, + ) + from openenv.core.env_server import ( + create_app, + ) +except ImportError as exc: # pragma: no cover - exercised only without openenv extra + raise RuntimeError( + "The OpenEnv adapter requires the openenv extra: " + "pip install 'dataforge15[openenv]'." + ) from exc + + +class DataForgeOpenEnvAction(OpenEnvAction): + """OpenEnv action wrapper for DataForge's typed action payloads.""" + + action_type: str = Field(min_length=1) + row_indices: list[int] | None = None + column_names: list[str] | None = None + query: str | None = None + sql: str | None = None + test_type: str | None = None + test: str | None = None + column: str | None = None + threshold: float | None = None + pattern: str | None = None + regex: str | None = None + expect_match: bool | None = None + claim: str | None = None + affected_rows: list[int] | None = None + affected_columns: list[str] | None = None + root_cause_type: str | None = None + error_indices: list[int] | None = None + row: int | None = None + issue_type: str | None = None + new_value: str | None = None + proposed_value: str | None = None + justification: str | None = None + fix_type: str | None = None + + def as_dataforge_payload(self) -> dict[str, Any]: + """Return the action payload expected by ``DataForgeEnv.step``.""" + payload = self.model_dump(exclude_none=True) + payload.pop("metadata", None) + return payload + + +class DataForgeOpenEnvObservation(OpenEnvObservation): + """OpenEnv observation model mirroring DataForge's native observation.""" + + visible_rows: list[dict[str, Any]] | None = None + detector_hints: list[str] | None = None + scratchpad_summary: str = "" + step_budget_remaining: int = 0 + tool_usage_history: list[dict[str, Any]] = Field(default_factory=list) + latest_result: dict[str, Any] | None = None + cumulative_reward: float = 0.0 + + +def _to_openenv_observation(payload: dict[str, Any]) -> DataForgeOpenEnvObservation: + """Convert a native DataForge observation dictionary into OpenEnv shape.""" + return DataForgeOpenEnvObservation( + visible_rows=payload.get("visible_rows"), + detector_hints=payload.get("detector_hints"), + scratchpad_summary=str(payload.get("scratchpad_summary", "")), + step_budget_remaining=int(payload.get("step_budget_remaining", 0)), + tool_usage_history=list(payload.get("tool_usage_history") or []), + latest_result=payload.get("latest_result"), + done=bool(payload.get("done", False)), + reward=payload.get("reward"), + cumulative_reward=float(payload.get("cumulative_reward", 0.0)), + metadata=dict(payload.get("metadata") or {}), + ) + + +class DataForgeOpenEnv(OpenEnvEnvironment): + """OpenEnv-native environment wrapper.""" + + SUPPORTS_CONCURRENT_SESSIONS = True + + def __init__(self) -> None: + super().__init__() + self._env = DataForgeEnv() + self._last_observation: DataForgeOpenEnvObservation | None = None + + def reset( + self, + seed: int | None = None, + episode_id: str | None = None, + **kwargs: Any, + ) -> DataForgeOpenEnvObservation: + """Reset the wrapped DataForge environment.""" + del episode_id, kwargs + result = self._env.reset(seed=seed) + observation = _to_openenv_observation(result.observation.model_dump(mode="json")) + self._last_observation = observation + return observation + + def step( + self, + action: DataForgeOpenEnvAction, + timeout_s: float | None = None, + **kwargs: Any, + ) -> DataForgeOpenEnvObservation: + """Step the wrapped DataForge environment.""" + del timeout_s, kwargs + result = self._env.step(action.as_dataforge_payload()) + observation = _to_openenv_observation(result.observation.model_dump(mode="json")) + self._last_observation = observation + return observation + + def state(self) -> DataForgeOpenEnvObservation: + """Return the latest observation or reset lazily.""" + if self._last_observation is None: + return self.reset() + return self._last_observation + + def close(self) -> None: + """Close the wrapped environment.""" + self._env.close() + + +app = create_app( + DataForgeOpenEnv, + DataForgeOpenEnvAction, + DataForgeOpenEnvObservation, + env_name="dataforge-env", + max_concurrent_envs=64, +) diff --git a/dataforge/env/reward.py b/dataforge/env/reward.py new file mode 100644 index 0000000000000000000000000000000000000000..598bba5d857872ff4bcb1fd44a39949b074d6a76 --- /dev/null +++ b/dataforge/env/reward.py @@ -0,0 +1,128 @@ +"""Reward engine for the DataForge RL environment. + +All constants and formulas are derived bit-for-bit from REWARD_DESIGN.md. + +Terminal score: detection_rate * 0.40 + fix_rate * 0.60 - false_positives * fp_rate +""" + +from __future__ import annotations + +from dataclasses import dataclass + +__all__ = [ + "DETECTION_WEIGHT", + "FALSE_POS_PENALTY_RATE", + "FIX_WEIGHT", + "LATE_STEP_THRESHOLD", + "P_FALSE_POS", + "P_INVALID", + "P_LATE_STEP", + "P_REINSPECT", + "P_WRONG_FIX", + "R_DIAGNOSE", + "R_EXPLORE", + "R_FIX", + "R_FIX_PARTIAL", + "R_JUSTIFY_BONUS", + "R_ROOT_CAUSE", + "R_TYPE_BONUS", + "SPAM_THRESHOLD", + "EpisodeMetrics", + "RewardEngine", +] + +# Positive rewards +R_DIAGNOSE: float = 0.10 +R_TYPE_BONUS: float = 0.05 +R_FIX: float = 0.15 +R_FIX_PARTIAL: float = 0.075 +R_JUSTIFY_BONUS: float = 0.05 +R_EXPLORE: float = 0.01 +R_ROOT_CAUSE: float = 0.10 + +# Negative penalties +P_FALSE_POS: float = -0.05 +P_WRONG_FIX: float = -0.08 +P_LATE_STEP: float = -0.02 +P_INVALID: float = -0.01 +P_REINSPECT: float = -0.01 + +# Thresholds +LATE_STEP_THRESHOLD: float = 0.80 +DETECTION_WEIGHT: float = 0.40 +FIX_WEIGHT: float = 0.60 +FALSE_POS_PENALTY_RATE: float = 0.05 +SPAM_THRESHOLD: float = 2.0 + + +@dataclass +class EpisodeMetrics: + """Accumulated metrics for terminal score computation.""" + + found_issues: int = 0 + total_issues: int = 0 + fixed_issues: int = 0 + fixable_issues: int = 0 + false_positives: int = 0 + + @property + def total_diagnoses(self) -> int: + """Total diagnosis attempts (correct + incorrect).""" + return self.found_issues + self.false_positives + + +class RewardEngine: + """Computes dense per-step and terminal rewards.""" + + def compute_terminal_score(self, metrics: EpisodeMetrics) -> float: + """Compute terminal score per REWARD_DESIGN.md formula.""" + if metrics.total_issues == 0: + return 0.0 + detection_rate = metrics.found_issues / metrics.total_issues + fix_rate = ( + metrics.fixed_issues / metrics.fixable_issues if metrics.fixable_issues > 0 else 0.0 + ) + fp_rate = FALSE_POS_PENALTY_RATE + if ( + metrics.total_issues > 0 + and metrics.total_diagnoses > SPAM_THRESHOLD * metrics.total_issues + ): + fp_rate *= 2.0 + penalty = metrics.false_positives * fp_rate + raw = detection_rate * DETECTION_WEIGHT + fix_rate * FIX_WEIGHT - penalty + return round(max(0.0, min(1.0, raw)), 4) + + def compute_late_penalty(self, step: int, max_steps: int) -> float: + """Return P_LATE_STEP if past 80% budget, else 0.0.""" + threshold = int(max_steps * LATE_STEP_THRESHOLD) + return P_LATE_STEP if step > threshold else 0.0 + + def compute_exploration_bonus( + self, + new_row_indices: set[int], + inspected_rows: set[int], + total_rows: int, + ground_truth_rows: set[int], + found_issue_rows: set[int], + ) -> float: + """Compute exploration bonus for newly-inspected rows.""" + if not new_row_indices: + return P_REINSPECT + undiscovered = sum( + 1 for r in new_row_indices if r in ground_truth_rows and r not in found_issue_rows + ) + bonus = undiscovered * R_EXPLORE + if total_rows > 0: + all_inspected = inspected_rows | new_row_indices + coverage_ratio = len(all_inspected) / total_rows + bonus += len(new_row_indices) * R_EXPLORE * 0.5 * (1.0 - coverage_ratio) + return bonus + + def diagnose_reward(self, type_match: bool) -> float: + """Reward for correct diagnosis.""" + return R_DIAGNOSE + (R_TYPE_BONUS if type_match else 0.0) + + def fix_reward(self, exact: bool, has_justification: bool) -> float: + """Reward for correct fix.""" + reward = R_FIX if exact else R_FIX_PARTIAL + return reward + (R_JUSTIFY_BONUS if has_justification else 0.0) diff --git a/dataforge/env/server.py b/dataforge/env/server.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ab125a4341d46c0345ba2691d88cbedbd226c1 --- /dev/null +++ b/dataforge/env/server.py @@ -0,0 +1,175 @@ +"""FastAPI server for the DataForge RL environment. + +Provides OpenEnv-compatible HTTP endpoints: + POST /reset — Start a new episode + POST /step — Execute an action + GET /state — Return current state snapshot + POST /close — No-op shutdown + GET /health — Liveness check + GET /metadata — Environment metadata + GET /schema — Action/observation JSON schemas +""" + +from __future__ import annotations + +import logging +import os +from threading import RLock +from typing import Any + +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from pydantic import TypeAdapter + +from dataforge.agent.tool_actions import Action +from dataforge.env.environment import DataForgeEnv, EnvState +from dataforge.env.observation import DataForgeObservation +from dataforge.http.problem import problem_exception_handler +from dataforge.observability import configure_fastapi_observability + +logger = logging.getLogger("dataforge.env.server") + + +def _build_cors_origins() -> list[str]: + """Build the explicit OpenEnv CORS allowlist from the environment.""" + raw_origins = os.environ.get("DATAFORGE_OPENENV_ORIGINS", "") + return [origin.strip() for origin in raw_origins.split(",") if origin.strip()] + + +def _build_cors_origin_regex() -> str | None: + """Allow local browser development only when explicitly enabled.""" + if os.environ.get("DATAFORGE_OPENENV_DEV") != "1": + return None + return r"^http://(?:localhost|127(?:\.\d{1,3}){3})(?::\d+)?$" + + +app = FastAPI( + title="DataForge Environment", + description="OpenEnv-compatible RL environment for data-quality repair.", + version="0.1.0", +) + +app.add_middleware( + CORSMiddleware, + allow_origins=_build_cors_origins(), + allow_origin_regex=_build_cors_origin_regex(), + allow_credentials=False, + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["*"], +) +app.add_exception_handler(HTTPException, problem_exception_handler) +configure_fastapi_observability(app, service_name="dataforge-openenv") + +_registry_lock = RLock() +_default_env = DataForgeEnv() +_sessions: dict[str, DataForgeEnv] = {} + + +def _get_env(episode_id: str | None) -> DataForgeEnv: + """Resolve an environment by episode id, preserving legacy no-id behavior.""" + if not episode_id: + return _default_env + with _registry_lock: + try: + return _sessions[episode_id] + except KeyError as exc: + raise HTTPException( + status_code=404, + detail={"error": "episode_not_found", "episode_id": episode_id}, + ) from exc + + +def _remember_env(env: DataForgeEnv, episode_id: str) -> None: + """Register a session and update the legacy default environment.""" + global _default_env + with _registry_lock: + _sessions[episode_id] = env + _default_env = env + + +@app.post("/reset") +async def reset(seed: int | None = None) -> dict[str, Any]: + """Reset the environment for a new episode.""" + env = DataForgeEnv() + result = env.reset(seed=seed) + episode_id = str(result.info["episode_id"]) + _remember_env(env, episode_id) + return result.model_dump(mode="json") + + +@app.post("/step") +async def step(action: dict[str, Any]) -> dict[str, Any]: + """Execute one agent action.""" + action_payload = dict(action) + raw_episode_id = action_payload.pop("episode_id", None) + episode_id = str(raw_episode_id) if raw_episode_id else None + result = _get_env(episode_id).step(action_payload) + return result.model_dump(mode="json") + + +@app.get("/state") +async def state(episode_id: str | None = None) -> dict[str, Any]: + """Return current environment state snapshot.""" + result = _get_env(episode_id).state() + return result.model_dump(mode="json") + + +@app.post("/close") +async def close(request: Request, episode_id: str | None = None) -> dict[str, Any]: + """No-op close endpoint for OpenEnv compatibility.""" + body_episode_id: str | None = None + if episode_id is None: + try: + payload = await request.json() + except Exception: + payload = None + if isinstance(payload, dict) and payload.get("episode_id"): + body_episode_id = str(payload["episode_id"]) + + target_episode_id = episode_id or body_episode_id + env = _get_env(target_episode_id) + env.close() + if target_episode_id: + with _registry_lock: + _sessions.pop(target_episode_id, None) + return {"status": "closed", "episode_id": target_episode_id} + + +@app.get("/health") +async def health() -> dict[str, Any]: + """Liveness check.""" + return {"status": "healthy", "environment": "dataforge-env"} + + +@app.get("/metadata") +async def metadata() -> dict[str, Any]: + """Environment metadata for OpenEnv discovery.""" + return { + "name": "dataforge-env", + "version": "0.1.0", + "description": ( + "DataForge RL Environment — agents learn to detect, diagnose, " + "and repair data-quality issues in tabular datasets." + ), + "action_types": [ + "INSPECT_ROWS", + "SQL_QUERY", + "STAT_TEST", + "PATTERN_MATCH", + "HYPOTHESIS", + "ROOT_CAUSE", + "DIAGNOSE", + "FIX", + ], + } + + +@app.get("/schema") +async def schema() -> dict[str, Any]: + """Return JSON schemas for action and observation models.""" + action_adapter: TypeAdapter[Action] = TypeAdapter(Action) + return { + "action": action_adapter.json_schema(), + "observation": DataForgeObservation.model_json_schema(), + "state": EnvState.model_json_schema(), + } diff --git a/dataforge/evaluation_contract.py b/dataforge/evaluation_contract.py new file mode 100644 index 0000000000000000000000000000000000000000..1045c0dfdb081634c3a7c16c8cbfbed4014609ee --- /dev/null +++ b/dataforge/evaluation_contract.py @@ -0,0 +1,76 @@ +"""Public evaluation evidence models for DataForge repair releases.""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any, Literal + +from pydantic import BaseModel, Field + +InferabilityLabel = Literal[ + "deterministic_normalization", + "context_derivable", + "external_reference_required", + "not_inferable_from_prompt", +] +PROMOTION_SLICE: InferabilityLabel = "deterministic_normalization" +ABSTENTION_SLICES = frozenset({"external_reference_required", "not_inferable_from_prompt"}) +AUXILIARY_SLICES = frozenset({"context_derivable"}) +PromotionStatus = Literal[ + "diagnostic_only", + "diagnostic_promoted", + "quality_improved_verified", + "public_quality_milestone", + "rejected", +] + + +class EvaluationTaskV2(BaseModel): + """One auditable, source-stable model grading task. + + Ground truth is retained for local grading but excluded from normal JSON + serialization so prompts and public reports cannot accidentally leak labels. + """ + + schema_version: Literal["evaluation_task_v2"] = "evaluation_task_v2" + task_id: str = Field(min_length=1) + prompt_hash: str = Field(min_length=64, max_length=64) + dataset_sha: str = Field(min_length=1) + split_id: str = Field(min_length=1) + inferability: InferabilityLabel + prompt: dict[str, Any] + allowed_columns: list[str] = Field(min_length=1) + valid_rows: list[int] = Field(min_length=1) + provenance: dict[str, Any] + hidden_ground_truth: list[dict[str, Any]] = Field(default_factory=list, exclude=True) + + model_config = {"frozen": True} + + +class ReleaseEvidenceV2(BaseModel): + """Serializable release-gate evidence for model and benchmark promotion.""" + + schema_version: Literal["release_evidence_v2"] = "release_evidence_v2" + model_repo: str = Field(min_length=1) + model_sha: str = Field(min_length=1) + dataset_repo: str = Field(min_length=1) + dataset_sha: str = Field(min_length=1) + strict_macro_f1: float = Field(ge=0.0, le=1.0) + canonicalized_macro_f1: float = Field(ge=0.0, le=1.0) + parse_success_rate: float = Field(ge=0.0, le=1.0) + schema_case_error_count: int = Field(ge=0) + promotion_slice: InferabilityLabel = PROMOTION_SLICE + slice_scores: dict[InferabilityLabel, dict[str, float | int]] = Field(default_factory=dict) + inferability_slice_scores: dict[InferabilityLabel, float] = Field(default_factory=dict) + package_versions: dict[str, str] = Field(default_factory=dict) + promotion_status: PromotionStatus + gate_failures: list[str] = Field(default_factory=list) + + model_config = {"frozen": True} + + +def prompt_sha256(prompt: dict[str, Any]) -> str: + """Hash a prompt payload with stable JSON serialization.""" + encoded = json.dumps(prompt, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() diff --git a/dataforge/fixtures/hospital_10rows.csv b/dataforge/fixtures/hospital_10rows.csv new file mode 100644 index 0000000000000000000000000000000000000000..8b4373908d9b876f07c5fefc7420077e17323040 --- /dev/null +++ b/dataforge/fixtures/hospital_10rows.csv @@ -0,0 +1,11 @@ +provider_number,hospital_name,city,state,zip_code,phone_number,rating,mortality_rate,readmission_rate,er_wait_time +PRV001,General Hospital,Springfield,IL,62701,2175550101,4.2,0.023,0.145,28 +PRV002,St. Mary Medical Center,Chicago,IL,60601,3125550202,3.8,0.031,0.162,35 +PRV001,Springfield Medical,Springfield,IL,62701,2175550303,4.5,0.019,0.138,22 +PRV003,Mercy Hospital,Peoria,IL,61602,3095550404,3.5,0.028,0.158,31 +PRV004,Northwestern Memorial,Chicago,IL,60611,not available,4.1,0.025,0.149,26 +PRV005,Rush University MC,Chicago,IL,60612,3125550606,45.0,0.022,0.141,29 +PRV006,Advocate Christ,Oak Lawn,IL,60453,7085550707,3.9,0.027,0.155,33 +PRV007,Loyola University MC,Maywood,IL,60153,7085550808,4.3,0.020,0.142,25 +PRV008,Presence St. Joseph,Joliet,IL,60435,8155550909,4.0,0.026,0.151,30 +PRV009,Edward Hospital,Naperville,IL,60540,6305551010,3.7,0.029,0.160,34 diff --git a/dataforge/fixtures/hospital_schema.yaml b/dataforge/fixtures/hospital_schema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d6c8b0d62cd858889f9dd346d148ba0e3d8c9fe --- /dev/null +++ b/dataforge/fixtures/hospital_schema.yaml @@ -0,0 +1,17 @@ +# Hospital dataset schema for DataForge profile command. + +columns: + provider_number: str + hospital_name: str + city: str + state: str + zip_code: str + phone_number: str + rating: float + mortality_rate: float + readmission_rate: float + er_wait_time: int + +functional_dependencies: + - determinant: [provider_number] + dependent: hospital_name diff --git a/dataforge/http/__init__.py b/dataforge/http/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8656f2f2bcffd121d2b9599a236cb6f39c0993f9 --- /dev/null +++ b/dataforge/http/__init__.py @@ -0,0 +1 @@ +"""HTTP helpers shared by DataForge backend surfaces.""" diff --git a/dataforge/http/problem.py b/dataforge/http/problem.py new file mode 100644 index 0000000000000000000000000000000000000000..1985d55b5c7289dfd2412803198da7a10da3fb16 --- /dev/null +++ b/dataforge/http/problem.py @@ -0,0 +1,99 @@ +"""RFC 9457 problem details helpers for FastAPI surfaces.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from fastapi import HTTPException, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel, ConfigDict, Field + + +class ProblemDetail(BaseModel): + """RFC 9457 problem detail response with extension members.""" + + type: str = Field(default="about:blank") + title: str + status: int + detail: str + instance: str | None = None + + model_config = ConfigDict(strict=True, extra="allow") + + +def problem_body( + *, + status: int, + title: str, + detail: str, + type_: str = "about:blank", + instance: str | None = None, + **extensions: Any, +) -> dict[str, Any]: + """Build a problem details JSON object.""" + body = ProblemDetail( + type=type_, + title=title, + status=status, + detail=detail, + instance=instance, + **extensions, + ) + return body.model_dump(mode="json", exclude_none=True) + + +def problem_response( + *, + status: int, + title: str, + detail: str, + type_: str = "about:blank", + instance: str | None = None, + headers: Mapping[str, str] | None = None, + **extensions: Any, +) -> JSONResponse: + """Return an RFC 9457 JSON response.""" + return JSONResponse( + status_code=status, + content=problem_body( + status=status, + title=title, + detail=detail, + type_=type_, + instance=instance, + **extensions, + ), + headers=headers, + media_type="application/problem+json", + ) + + +async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: + """Normalize FastAPI HTTPException values into problem details.""" + raw_detail = exc.detail + extensions: dict[str, Any] = {} + if isinstance(raw_detail, dict): + error_code = str(raw_detail.get("error", "http_error")) + message = str(raw_detail.get("message") or raw_detail.get("detail") or error_code) + extensions.update(raw_detail) + else: + error_code = "http_error" + message = str(raw_detail) + + return problem_response( + status=exc.status_code, + type_=f"https://dataforge.local/problems/{error_code}", + title=error_code.replace("_", " ").title(), + detail=message, + instance=str(request.url.path), + headers=exc.headers, + **extensions, + ) + + +async def problem_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Adapter with the broad exception signature Starlette expects.""" + if isinstance(exc, HTTPException): + return await http_exception_handler(request, exc) + raise exc diff --git a/dataforge/integrations/dbt.py b/dataforge/integrations/dbt.py new file mode 100644 index 0000000000000000000000000000000000000000..1a564b05758b231b3df63c2a70399fe06c248ae6 --- /dev/null +++ b/dataforge/integrations/dbt.py @@ -0,0 +1 @@ +"""The dbt integration lives in the separate ``dataforge15-dbt`` package.""" diff --git a/dataforge/observability.py b/dataforge/observability.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac0c41bbee20b70419064c699c3ba960e700f18 --- /dev/null +++ b/dataforge/observability.py @@ -0,0 +1,76 @@ +"""Optional OpenTelemetry hooks for DataForge backend surfaces.""" + +from __future__ import annotations + +import os +from collections.abc import Iterator +from contextlib import contextmanager, nullcontext +from importlib import import_module +from typing import Any + +_SENSITIVE_ATTR_FRAGMENTS = ("authorization", "cookie", "token", "key", "secret", "password") + + +def _otel_enabled() -> bool: + """Return whether optional OpenTelemetry instrumentation is enabled.""" + return os.environ.get("DATAFORGE_OTEL_ENABLED", "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + +def _safe_attrs(attributes: dict[str, Any]) -> dict[str, str | int | float | bool]: + """Keep only scalar, non-sensitive telemetry attributes.""" + safe: dict[str, str | int | float | bool] = {} + for key, value in attributes.items(): + lowered = key.lower() + if any(fragment in lowered for fragment in _SENSITIVE_ATTR_FRAGMENTS): + continue + if lowered in {"row_values", "rows", "payload", "source_bytes", "csv"}: + continue + if isinstance(value, str | int | float | bool): + safe[key] = value + return safe + + +def configure_fastapi_observability(app: Any, *, service_name: str) -> bool: + """Instrument a FastAPI app when OpenTelemetry is explicitly enabled.""" + if not _otel_enabled(): + return False + try: + fastapi_instrumentation = import_module("opentelemetry.instrumentation.fastapi") + trace_module = import_module("opentelemetry.trace") + except ImportError: + return False + + app.state.dataforge_service_name = service_name + fastapi_instrumentation.FastAPIInstrumentor.instrument_app( + app, + tracer_provider=trace_module.get_tracer_provider(), + excluded_urls="/api/docs,/docs,/redoc,/openapi.json", + ) + return True + + +@contextmanager +def repair_stage_span(stage: str, **attributes: Any) -> Iterator[None]: + """Create a repair-stage span when OpenTelemetry is available.""" + if not _otel_enabled(): + with nullcontext(): + yield + return + + try: + trace_module = import_module("opentelemetry.trace") + except ImportError: + with nullcontext(): + yield + return + + tracer = trace_module.get_tracer("dataforge.repair") + with tracer.start_as_current_span(stage) as span: + for key, value in _safe_attrs(attributes).items(): + span.set_attribute(key, value) + yield diff --git a/dataforge/py.typed b/dataforge/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/dataforge/py.typed @@ -0,0 +1 @@ + diff --git a/dataforge/release/__init__.py b/dataforge/release/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20e35672c4cb5fbb81e2ee645056a88eda0b5d06 --- /dev/null +++ b/dataforge/release/__init__.py @@ -0,0 +1,2 @@ +"""Release verification helpers for DataForge.""" + diff --git a/dataforge/release/doctor.py b/dataforge/release/doctor.py new file mode 100644 index 0000000000000000000000000000000000000000..398262eba848486af2b75975ec7911c60e6f68c0 --- /dev/null +++ b/dataforge/release/doctor.py @@ -0,0 +1,201 @@ +"""Release doctor checks for DataForge public-surface gates.""" + +from __future__ import annotations + +import json +import shutil +import socket +import subprocess +import sys +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +EXPECTED_HF_USER = "Praneshrajan15" +DEFAULT_KAGGLE_CREDENTIALS = Path.home() / ".kaggle" / "credentials.json" +STALE_KAGGLE_JSON = Path.home() / ".kaggle" / "kaggle.json" +DATAFORGE_DOMAIN = "dataforge.dev" + + +@dataclass(frozen=True) +class DoctorCheck: + """One release doctor check result.""" + + name: str + ok: bool + detail: str + metadata: dict[str, Any] + + +@dataclass(frozen=True) +class DoctorReport: + """Machine-readable release doctor report.""" + + ok: bool + checks: list[DoctorCheck] + secrets_printed: bool = False + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable report.""" + return asdict(self) + + +def _check_hugging_face() -> DoctorCheck: + try: + from huggingface_hub import HfApi, get_token + except ImportError as exc: + return DoctorCheck("hugging_face", False, f"huggingface_hub missing: {exc}", {}) + + token_present = bool(get_token()) + if not token_present: + return DoctorCheck( + "hugging_face", + False, + "No cached Hugging Face token found.", + {"expected_user": EXPECTED_HF_USER, "token_present": False}, + ) + try: + info = HfApi(token=get_token()).whoami() + except Exception as exc: + return DoctorCheck( + "hugging_face", + False, + f"Could not resolve Hugging Face identity: {exc}", + {"expected_user": EXPECTED_HF_USER, "token_present": True}, + ) + user = str(info.get("name", "")) + return DoctorCheck( + "hugging_face", + user == EXPECTED_HF_USER, + "Authenticated with expected Hugging Face user." + if user == EXPECTED_HF_USER + else f"Authenticated as {user!r}, expected {EXPECTED_HF_USER!r}.", + {"user": user, "expected_user": EXPECTED_HF_USER, "token_present": True}, + ) + + +def _load_kaggle_oauth(path: Path) -> dict[str, Any]: + if path.name == "kaggle.json": + raise RuntimeError( + f"Refusing to read stale legacy Kaggle API key file: {path}. " + f"Use OAuth credentials at {DEFAULT_KAGGLE_CREDENTIALS}." + ) + if not path.exists(): + raise RuntimeError(f"Missing Kaggle OAuth credentials: {path}") + payload = json.loads(path.read_text(encoding="utf-8-sig")) + if not isinstance(payload, dict): + raise RuntimeError("Kaggle OAuth credentials must be a JSON object.") + required = {"refresh_token", "access_token", "access_token_expiration", "username", "scopes"} + missing = sorted(required - set(payload)) + if missing: + raise RuntimeError("Kaggle OAuth credentials missing fields: " + ", ".join(missing)) + if not isinstance(payload.get("username"), str) or not payload["username"]: + raise RuntimeError("Kaggle OAuth credentials are missing username.") + scopes = payload.get("scopes") + if not isinstance(scopes, list) or not scopes: + raise RuntimeError("Kaggle OAuth credentials are missing scopes.") + return payload + + +def _check_kaggle_oauth(credentials_path: Path = DEFAULT_KAGGLE_CREDENTIALS) -> DoctorCheck: + try: + payload = _load_kaggle_oauth(credentials_path) + except Exception as exc: + return DoctorCheck("kaggle_oauth", False, str(exc), {"credential_path": str(credentials_path)}) + return DoctorCheck( + "kaggle_oauth", + True, + "Kaggle OAuth credentials are present and legacy key is ignored.", + { + "credential_path": str(credentials_path), + "username": payload["username"], + "scopes_count": len(payload["scopes"]), + "legacy_kaggle_json_exists": STALE_KAGGLE_JSON.exists(), + "legacy_kaggle_json_used": False, + "tokens_printed": False, + }, + ) + + +def _check_cloudflare() -> DoctorCheck: + npx = shutil.which("npx") or shutil.which("npx.cmd") + if npx is None: + return DoctorCheck("cloudflare", False, "npx/wrangler not available on PATH.", {}) + command = [npx, "wrangler", "whoami"] + try: + result = subprocess.run( + command, + capture_output=True, + text=True, + encoding="utf-8", + errors="replace", + timeout=60, + check=False, + ) + except FileNotFoundError: + return DoctorCheck("cloudflare", False, "npx/wrangler not available on PATH.", {}) + except subprocess.TimeoutExpired: + return DoctorCheck("cloudflare", False, "wrangler whoami timed out.", {}) + + output = result.stdout + result.stderr + ok = result.returncode == 0 and "logged in" in output.lower() + has_route_scope = "workers_routes (write)" in output + return DoctorCheck( + "cloudflare", + ok and has_route_scope, + "Wrangler OAuth is logged in and can write Worker routes." + if ok and has_route_scope + else "Wrangler is not logged in with the required Workers route scope.", + { + "wrangler_available": result.returncode == 0, + "logged_in": ok, + "workers_routes_write": has_route_scope, + "command": " ".join(command), + }, + ) + + +def _check_domain() -> DoctorCheck: + try: + _, _, ips = socket.gethostbyname_ex(DATAFORGE_DOMAIN) + except OSError: + ips = [] + return DoctorCheck( + "dataforge_domain", + True, + "Domain check completed; route activation is verified by playground deploy checks.", + {"domain": DATAFORGE_DOMAIN, "a_records_seen": len(ips), "route": "dataforge.dev/playground*"}, + ) + + +def run_doctor(*, kaggle_credentials: Path = DEFAULT_KAGGLE_CREDENTIALS) -> DoctorReport: + """Run all local release doctor checks.""" + checks = [ + _check_hugging_face(), + _check_kaggle_oauth(kaggle_credentials), + _check_cloudflare(), + _check_domain(), + ] + return DoctorReport(ok=all(check.ok for check in checks), checks=checks) + + +def main(argv: list[str] | None = None) -> int: + """Script entrypoint used by CI and local release work.""" + import argparse + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--json", action="store_true", help="Print JSON instead of text.") + parser.add_argument("--kaggle-credentials", type=Path, default=DEFAULT_KAGGLE_CREDENTIALS) + args = parser.parse_args(argv) + report = run_doctor(kaggle_credentials=args.kaggle_credentials) + if args.json: + print(json.dumps(report.to_dict(), indent=2, sort_keys=True)) + else: + for check in report.checks: + status = "ok" if check.ok else "fail" + print(f"{status:4} {check.name}: {check.detail}") + return 0 if report.ok else 2 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/dataforge/repair_contract.py b/dataforge/repair_contract.py new file mode 100644 index 0000000000000000000000000000000000000000..29da3951c39b2e780113e28515ba9e044dbb66ce --- /dev/null +++ b/dataforge/repair_contract.py @@ -0,0 +1,468 @@ +"""Canonical prompt, parsing, and scoring contract for DataForge repairs.""" + +from __future__ import annotations + +import json +import re +from collections import Counter, OrderedDict +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, Protocol + +from pydantic import BaseModel, Field, ValidationError, model_validator + +CONTRACT_VERSION_V1 = "repair_contract_v1" +CONTRACT_VERSION_V2 = "repair_contract_v2" +CONTRACT_VERSION = CONTRACT_VERSION_V2 + +SYSTEM_PROMPT = ( + "You repair tabular data by proposing exact cell replacements. " + "Rows must be absolute row ids from valid_rows and columns must exactly match one of " + "the allowed_columns values. " + "Use only the provided dirty target rows and optional context rows. " + "Return strict JSON only in this object shape: " + '{"action":"submit_repairs","repairs":[{"row":0,"column":"Column",' + '"new_value":"value","reason":"why"}]}. ' + 'Use {"action":"finish","repairs":[]} when no cells should be changed. ' + "Do not wrap the JSON in markdown code fences." +) + +_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL) + + +class RepairLike(Protocol): + """Minimal shape shared by repair objects across DataForge packages.""" + + @property + def row(self) -> int: ... + + @property + def column(self) -> str: ... + + @property + def new_value(self) -> str: ... + + @property + def reason(self) -> str: ... + + +class TruthLike(Protocol): + """Minimal shape shared by ground-truth cell objects across packages.""" + + @property + def row(self) -> int: ... + + @property + def column(self) -> str: ... + + @property + def clean_value(self) -> str: ... + + +class RepairFix(BaseModel): + """One exact cell replacement proposed by a repair agent.""" + + row: int = Field(ge=0) + column: str = Field(min_length=1) + new_value: str + reason: str = Field(default="repair proposal", min_length=1) + + model_config = {"frozen": True} + + +class RepairAction(BaseModel): + """The only JSON action shape accepted by the repair contract.""" + + action: Literal["submit_repairs", "finish"] + repairs: list[RepairFix] = Field(default_factory=list) + + model_config = {"frozen": True} + + @model_validator(mode="after") + def _finish_must_be_empty(self) -> RepairAction: + if self.action == "finish" and self.repairs: + raise ValueError("finish actions must not include repairs") + return self + + +class RepairParseResult(BaseModel): + """Parsed repair action plus diagnostics suitable for release gates.""" + + ok: bool + action: RepairAction | None = None + error_kind: ( + Literal[ + "parse_failure", + "truncated_json", + "schema_error", + "invalid_column", + "invalid_row", + ] + | None + ) = None + error_message: str | None = None + diagnostics: dict[str, int | str | bool] = Field(default_factory=dict) + + model_config = {"frozen": True} + + +class RepairScore(BaseModel): + """Exact-match cell repair metrics.""" + + tp: int = Field(ge=0) + fp: int = Field(ge=0) + fn: int = Field(ge=0) + precision: float = Field(ge=0.0, le=1.0) + recall: float = Field(ge=0.0, le=1.0) + f1: float = Field(ge=0.0, le=1.0) + + model_config = {"frozen": True} + + +def _as_jsonable_rows(rows: Sequence[Mapping[str, Any]]) -> list[dict[str, str]]: + """Return rows as stable string-valued mappings while preserving ``_row``.""" + rendered: list[dict[str, str]] = [] + for row in rows: + rendered_row: dict[str, str] = {} + for key, value in row.items(): + rendered_row[str(key)] = str(value) + rendered.append(rendered_row) + return rendered + + +def _valid_rows_from_target_rows(target_rows: Sequence[Mapping[str, Any]]) -> list[int]: + """Return absolute row ids from target rows, falling back to local ids for legacy rows.""" + valid_rows: list[int] = [] + for fallback_row, row in enumerate(target_rows): + raw_row = row.get("_row", fallback_row) + valid_rows.append(int(str(raw_row))) + return valid_rows + + +def build_repair_user_payload( + *, + schema_summary: Mapping[str, Any], + target_rows: Sequence[Mapping[str, Any]], + context_rows: Sequence[Mapping[str, Any]] = (), + allowed_columns: Sequence[str], + valid_rows: Sequence[int] | None = None, + label_source: str | None = None, + dataset_note: str | None = None, + metadata: Mapping[str, Any] | None = None, + contract_version: str = CONTRACT_VERSION, +) -> dict[str, Any]: + """Build the canonical user payload for repair SFT and evaluation.""" + payload: dict[str, Any] = { + "contract_version": contract_version, + "schema_summary": dict(schema_summary), + "allowed_columns": list(allowed_columns), + "valid_rows": list(valid_rows) + if valid_rows is not None + else _valid_rows_from_target_rows(target_rows), + "target_rows": _as_jsonable_rows(target_rows), + "context_rows": _as_jsonable_rows(context_rows), + } + if label_source is not None: + payload["label_source"] = label_source + if dataset_note is not None: + payload["dataset_note"] = dataset_note + if metadata is not None: + payload["metadata"] = dict(metadata) + return payload + + +def render_repair_messages( + *, + schema_summary: Mapping[str, Any], + target_rows: Sequence[Mapping[str, Any]], + allowed_columns: Sequence[str], + valid_rows: Sequence[int] | None = None, + context_rows: Sequence[Mapping[str, Any]] = (), + label_source: str | None = None, + dataset_note: str | None = None, + metadata: Mapping[str, Any] | None = None, + repairs: Sequence[RepairLike] | None = None, + contract_version: str = CONTRACT_VERSION, +) -> list[dict[str, str]]: + """Render canonical chat messages for a repair task. + + When ``repairs`` is ``None``, only system and user messages are returned. + When repairs are provided, an assistant message is appended for SFT. + """ + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": json.dumps( + build_repair_user_payload( + schema_summary=schema_summary, + target_rows=target_rows, + context_rows=context_rows, + allowed_columns=allowed_columns, + valid_rows=valid_rows, + label_source=label_source, + dataset_note=dataset_note, + metadata=metadata, + contract_version=contract_version, + ), + sort_keys=True, + separators=(",", ":"), + ), + }, + ] + if repairs is not None: + repair_fixes = [ + RepairFix( + row=repair.row, + column=repair.column, + new_value=repair.new_value, + reason=repair.reason, + ) + for repair in repairs + ] + messages.append( + { + "role": "assistant", + "content": json.dumps( + RepairAction( + action="submit_repairs" if repair_fixes else "finish", + repairs=repair_fixes, + ).model_dump(mode="json"), + sort_keys=True, + separators=(",", ":"), + ), + } + ) + return messages + + +def _strip_fence(text: str) -> str: + """Strip a single markdown JSON fence if the model returned one.""" + stripped = text.strip() + match = _JSON_FENCE_RE.search(stripped) + return match.group(1).strip() if match else stripped + + +def extract_json_payload(text: str) -> object: + """Extract the first complete JSON object or array from model text.""" + clean_text = _strip_fence(text) + decoder = json.JSONDecoder() + saw_start = False + for offset, char in enumerate(clean_text): + if char not in "[{": + continue + saw_start = True + try: + payload, _end = decoder.raw_decode(clean_text[offset:]) + except json.JSONDecodeError: + continue + if isinstance(payload, dict | list): + return payload + if saw_start: + raise ValueError("truncated_json") + raise ValueError("parse_failure") + + +def _schema_case_error(column: str, allowed_columns: set[str]) -> bool: + """Return whether ``column`` only differs from an allowed column by case.""" + return column.lower() in {allowed.lower() for allowed in allowed_columns} + + +def parse_repair_action( + text: str, + *, + allowed_columns: Iterable[str] | None = None, + valid_rows: Iterable[int] | None = None, + require_explicit_action: bool = False, +) -> RepairParseResult: + """Parse model text into a canonical repair action without raising. + + By default this remains permissive enough to read legacy v1 artifacts. Pass + ``allowed_columns``, ``valid_rows``, and ``require_explicit_action=True`` for + the v2 release-gate contract. + """ + try: + payload = extract_json_payload(text) + except ValueError as exc: + if str(exc) == "truncated_json": + return RepairParseResult( + ok=False, + error_kind="truncated_json", + error_message=str(exc), + ) + return RepairParseResult(ok=False, error_kind="parse_failure", error_message=str(exc)) + + diagnostics: dict[str, int | str | bool] = {} + if isinstance(payload, list): + if require_explicit_action: + return RepairParseResult( + ok=False, + error_kind="schema_error", + error_message="repair payload must include an explicit action", + ) + payload = {"action": "submit_repairs", "repairs": payload} + if not isinstance(payload, dict): + return RepairParseResult( + ok=False, + error_kind="schema_error", + error_message="repair payload must be a JSON object or array", + ) + if "repairs" in payload and "action" not in payload: + if require_explicit_action: + return RepairParseResult( + ok=False, + error_kind="schema_error", + error_message="repair payload must include an explicit action", + ) + payload = {**payload, "action": "submit_repairs"} + try: + action = RepairAction.model_validate(payload) + except ValidationError as exc: + return RepairParseResult(ok=False, error_kind="schema_error", error_message=str(exc)) + + normalized_repairs = normalize_fixes(action.repairs) + duplicate_count = len(action.repairs) - len(normalized_repairs) + if duplicate_count: + diagnostics["duplicate_cell_count"] = duplicate_count + action = RepairAction(action=action.action, repairs=normalized_repairs) + + if allowed_columns is not None: + allowed = set(allowed_columns) + for repair in action.repairs: + if repair.column in allowed: + continue + diagnostics["invalid_column"] = repair.column + diagnostics["schema_case_error"] = _schema_case_error(repair.column, allowed) + return RepairParseResult( + ok=False, + error_kind="invalid_column", + error_message=f"column {repair.column!r} is not in allowed_columns", + diagnostics=diagnostics, + ) + + if valid_rows is not None: + rows = {int(row) for row in valid_rows} + for repair in action.repairs: + if repair.row in rows: + continue + diagnostics["invalid_row"] = repair.row + return RepairParseResult( + ok=False, + error_kind="invalid_row", + error_message=f"row {repair.row} is not in valid_rows", + diagnostics=diagnostics, + ) + + return RepairParseResult(ok=True, action=action, diagnostics=diagnostics) + + +def normalize_fixes(fixes: Iterable[RepairLike]) -> list[RepairFix]: + """Collapse repairs to one final prediction per cell using last-write-wins.""" + by_cell: OrderedDict[tuple[int, str], RepairFix] = OrderedDict() + for fix in fixes: + normalized = RepairFix( + row=fix.row, + column=fix.column, + new_value=fix.new_value, + reason=fix.reason, + ) + key = (normalized.row, normalized.column) + if key in by_cell: + del by_cell[key] + by_cell[key] = normalized + return list(by_cell.values()) + + +def canonicalize_cell_value(value: str) -> str: + """Return a diagnostics-only canonical value for fuzzy F1 reporting.""" + return " ".join(str(value).strip().casefold().split()) + + +def _strict_cell_value(value: str) -> str: + """Return the official exact-match value normalization.""" + return str(value).rstrip() + + +def score_repair_fixes( + ground_truth: Iterable[TruthLike], + fixes: Iterable[RepairLike], + *, + canonicalize_values: bool = False, +) -> RepairScore: + """Score repairs by exact row, column, and string value match.""" + normalized = normalize_fixes(fixes) + value_fn = canonicalize_cell_value if canonicalize_values else _strict_cell_value + expected = {(cell.row, cell.column): value_fn(str(cell.clean_value)) for cell in ground_truth} + matched: set[tuple[int, str]] = set() + tp = 0 + fp = 0 + for fix in normalized: + key = (fix.row, fix.column) + expected_value = expected.get(key) + if expected_value is not None and value_fn(fix.new_value) == expected_value: + tp += 1 + matched.add(key) + else: + fp += 1 + fn = len(expected) - len(matched) + precision = tp / (tp + fp) if (tp + fp) else 0.0 + recall = tp / (tp + fn) if (tp + fn) else 0.0 + f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0 + return RepairScore( + tp=tp, + fp=fp, + fn=fn, + precision=round(precision, 4), + recall=round(recall, 4), + f1=round(f1, 4), + ) + + +def score_repair_fixes_canonicalized( + ground_truth: Iterable[TruthLike], + fixes: Iterable[RepairLike], +) -> RepairScore: + """Diagnostics-only F1 after conservative value canonicalization.""" + return score_repair_fixes(ground_truth, fixes, canonicalize_values=True) + + +def repair_failure_taxonomy( + *, + ground_truth: Iterable[TruthLike], + fixes: Iterable[RepairLike], + allowed_columns: Iterable[str], + valid_rows: Iterable[int], +) -> dict[str, int]: + """Classify exact-match failures without changing official scoring.""" + columns = set(allowed_columns) + lowercase_columns = {column.lower(): column for column in columns} + rows = set(valid_rows) + truth_map = {(cell.row, cell.column): str(cell.clean_value) for cell in ground_truth} + raw_fixes = list(fixes) + normalized_fixes = normalize_fixes(raw_fixes) + predictions = {(fix.row, fix.column): fix.new_value for fix in normalized_fixes} + counts: Counter[str] = Counter() + duplicate_count = len(raw_fixes) - len(normalized_fixes) + if duplicate_count: + counts["duplicate_cell"] += duplicate_count + + for fix in normalized_fixes: + key = (fix.row, fix.column) + if fix.column not in columns: + if fix.column.lower() in lowercase_columns: + counts["schema_case_error"] += 1 + else: + counts["wrong_cell"] += 1 + continue + if fix.row not in rows: + counts["wrong_cell"] += 1 + continue + if key not in truth_map: + counts["overrepair"] += 1 + continue + if truth_map[key] != fix.new_value: + counts["wrong_value"] += 1 + + for key in truth_map: + if key not in predictions: + counts["missed_repair"] += 1 + return {kind: count for kind, count in sorted(counts.items()) if count} diff --git a/dataforge/repairers/__init__.py b/dataforge/repairers/__init__.py index 767f9403e304353b524528b9943a3fc12a932d29..4d90b71e3764d8792cd8d8d005b79a0d7b39607f 100644 --- a/dataforge/repairers/__init__.py +++ b/dataforge/repairers/__init__.py @@ -4,13 +4,12 @@ from __future__ import annotations from pathlib import Path -import pandas as pd - from dataforge.detectors.base import Issue, Schema from dataforge.repairers.base import ProposedFix, RepairAttempt, Repairer, RetryContext from dataforge.repairers.decimal_shift import DecimalShiftRepairer from dataforge.repairers.fd_violation import FDViolationRepairer from dataforge.repairers.type_mismatch import TypeMismatchRepairer +from dataforge.table import TableLike __all__ = [ "DecimalShiftRepairer", @@ -45,7 +44,7 @@ def build_repairers( def propose_fixes( issues: list[Issue], - df: pd.DataFrame, + df: TableLike, schema: Schema | None, *, cache_dir: Path | None, diff --git a/dataforge/repairers/base.py b/dataforge/repairers/base.py index 501e5edcc2d74a9dbe486bd40b0b4c5d60d07e0e..fce975ad8639506259eee45d2679e80125b03622 100644 --- a/dataforge/repairers/base.py +++ b/dataforge/repairers/base.py @@ -4,10 +4,10 @@ from __future__ import annotations from typing import Literal, Protocol -import pandas as pd from pydantic import BaseModel, Field from dataforge.detectors.base import Issue, Schema +from dataforge.table import TableLike from dataforge.transactions.txn import CellFix ProvenanceLiteral = Literal["deterministic", "llm_cache", "llm_live"] @@ -69,7 +69,7 @@ class Repairer(Protocol): def propose( self, issue: Issue, - df: pd.DataFrame, + df: TableLike, schema: Schema | None, retry_context: RetryContext | None = None, ) -> ProposedFix | None: diff --git a/dataforge/repairers/decimal_shift.py b/dataforge/repairers/decimal_shift.py index 870ec18bacd49086766fff37c1ef2d686b44b29e..eb5677ce62133bdc17bb7323b1cd75af195d0373 100644 --- a/dataforge/repairers/decimal_shift.py +++ b/dataforge/repairers/decimal_shift.py @@ -2,10 +2,9 @@ from __future__ import annotations -import pandas as pd - from dataforge.detectors.base import Issue, Schema from dataforge.repairers.base import ProposedFix, RetryContext +from dataforge.table import TableLike, cell_value, column_names, row_count from dataforge.transactions.txn import CellFix @@ -15,7 +14,7 @@ class DecimalShiftRepairer: def propose( self, issue: Issue, - df: pd.DataFrame, + df: TableLike, schema: Schema | None, retry_context: RetryContext | None = None, ) -> ProposedFix | None: @@ -23,10 +22,10 @@ class DecimalShiftRepairer: del schema, retry_context if issue.issue_type != "decimal_shift" or issue.expected is None: return None - if issue.row >= len(df.index) or issue.column not in df.columns: + if issue.row >= row_count(df) or issue.column not in column_names(df): return None - old_value = str(df.at[issue.row, issue.column]) + old_value = cell_value(df, issue.row, issue.column) if old_value == issue.expected: return None diff --git a/dataforge/repairers/fd_violation.py b/dataforge/repairers/fd_violation.py index a18d598fb2c1d3c9cd2b1a1f735cd64c0976ee3e..f587e42123a674b3157af482fd7d00ed19efc33e 100644 --- a/dataforge/repairers/fd_violation.py +++ b/dataforge/repairers/fd_violation.py @@ -6,16 +6,17 @@ import asyncio import json from collections import Counter from pathlib import Path -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict -import pandas as pd - -from dataforge.agent.providers import Message, complete from dataforge.detectors.base import FunctionalDependency, Issue, Schema from dataforge.repairers.base import ProposedFix, ProvenanceLiteral, RetryContext +from dataforge.table import TableLike, cell_value, column_names, row_count from dataforge.transactions.log import sha256_bytes from dataforge.transactions.txn import CellFix +if TYPE_CHECKING: + from dataforge.agent.providers import Message + def _normalize_cell(value: object) -> str: """Normalize a DataFrame cell into a comparable string.""" @@ -46,7 +47,7 @@ class FDViolationRepairer: def _propose( self, issue: Issue, - df: pd.DataFrame, + df: TableLike, schema: Schema | None, retry_context: RetryContext | None, ) -> ProposedFix | None: @@ -54,7 +55,7 @@ class FDViolationRepairer: del retry_context if issue.issue_type != "fd_violation" or schema is None: return None - if issue.row >= len(df.index) or issue.column not in df.columns: + if issue.row >= row_count(df) or issue.column not in column_names(df): return None for fd in schema.functional_dependencies: @@ -64,11 +65,11 @@ class FDViolationRepairer: if group_df is None: continue - counts = Counter(_normalize_cell(value) for value in group_df[fd.dependent]) + counts = Counter(row[fd.dependent] for row in group_df) if len(counts) <= 1: continue - old_value = _normalize_cell(df.at[issue.row, issue.column]) + old_value = cell_value(df, issue.row, issue.column) chosen_majority = self._deterministic_choice(counts) if chosen_majority is not None: if chosen_majority == old_value: @@ -85,7 +86,7 @@ class FDViolationRepairer: def propose( self, issue: Issue, - df: pd.DataFrame, + df: TableLike, schema: Schema | None, retry_context: RetryContext | None = None, ) -> ProposedFix | None: @@ -94,23 +95,27 @@ class FDViolationRepairer: def _matching_group( self, - df: pd.DataFrame, + df: TableLike, row_index: int, fd: FunctionalDependency, - ) -> pd.DataFrame | None: + ) -> list[dict[str, str]] | None: """Return the determinant group containing the issue row.""" required_columns = [*fd.determinant, fd.dependent] - if any(column not in df.columns for column in required_columns): + if any(column not in column_names(df) for column in required_columns): return None - mask = pd.Series([True] * len(df.index), index=df.index) - for column in fd.determinant: - mask &= df[column].astype(str) == _normalize_cell(df.at[row_index, column]) - - group_df = df.loc[mask, required_columns] - if group_df.empty: + determinant_values = { + column: cell_value(df, row_index, column) for column in fd.determinant + } + group_rows: list[dict[str, str]] = [] + for row in range(row_count(df)): + if all(cell_value(df, row, column) == value for column, value in determinant_values.items()): + group_rows.append( + {column: cell_value(df, row, column) for column in required_columns} + ) + if not group_rows: return None - return group_df + return group_rows @staticmethod def _deterministic_choice(counts: Counter[str]) -> str | None: @@ -125,7 +130,7 @@ class FDViolationRepairer: def _choose_with_cache( self, fd: FunctionalDependency, - group_df: pd.DataFrame, + group_df: list[dict[str, str]], old_value: str, ) -> _Choice | None: """Choose a repaired value via cache-backed LLM fallback.""" @@ -135,7 +140,7 @@ class FDViolationRepairer: prompt_payload = { "determinant": fd.determinant, "dependent": fd.dependent, - "rows": group_df.to_dict(orient="records"), + "rows": group_df, "current_value": old_value, } prompt_text = json.dumps(prompt_payload, sort_keys=True) @@ -147,6 +152,14 @@ class FDViolationRepairer: chosen_value = str(cached["chosen_value"]) return {"value": chosen_value, "provenance": "llm_cache"} + try: + from dataforge.agent.providers import complete + except ImportError as exc: + raise RuntimeError( + "LLM-backed FD repair requires the provider extra: " + "pip install 'dataforge15[providers]'." + ) from exc + messages: list[Message] = [ { "role": "system", diff --git a/dataforge/repairers/type_mismatch.py b/dataforge/repairers/type_mismatch.py index 29a008a32b34574faf1852b972a43ae1e4ac8150..d5acfa87641364cc512930f318a2d0c043216e17 100644 --- a/dataforge/repairers/type_mismatch.py +++ b/dataforge/repairers/type_mismatch.py @@ -3,11 +3,11 @@ from __future__ import annotations import re - -import pandas as pd +from collections.abc import Iterable from dataforge.detectors.base import Issue, Schema from dataforge.repairers.base import ProposedFix, RetryContext +from dataforge.table import TableLike, cell_value, column_names, column_values, row_count from dataforge.transactions.txn import CellFix _NUMERIC_RE = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$") @@ -21,7 +21,7 @@ def _looks_numeric(value: str) -> bool: return bool(_NUMERIC_RE.match(value.strip())) -def _is_predominantly_numeric(series: pd.Series) -> bool: +def _is_predominantly_numeric(series: Iterable[object]) -> bool: """Return whether the non-empty series is mostly numeric strings.""" normalized = [str(value).strip() for value in series if str(value).strip()] if not normalized: @@ -36,7 +36,7 @@ class TypeMismatchRepairer: def propose( self, issue: Issue, - df: pd.DataFrame, + df: TableLike, schema: Schema | None, retry_context: RetryContext | None = None, ) -> ProposedFix | None: @@ -44,14 +44,14 @@ class TypeMismatchRepairer: del schema, retry_context if issue.issue_type != "type_mismatch": return None - if issue.row >= len(df.index) or issue.column not in df.columns: + if issue.row >= row_count(df) or issue.column not in column_names(df): return None - old_value = str(df.at[issue.row, issue.column]) + old_value = cell_value(df, issue.row, issue.column) normalized_old = old_value.strip().lower() if normalized_old not in _MISSING_SENTINELS: return None - if not _is_predominantly_numeric(df[issue.column]): + if not _is_predominantly_numeric(column_values(df, issue.column)): return None if old_value == "": return None diff --git a/dataforge/safety/constitution.py b/dataforge/safety/constitution.py index 691dddcd106afa75a5a09d06053bb6f59f7ae5c0..fc40d1dbda45893cbcd2e691fde79977755cac01 100644 --- a/dataforge/safety/constitution.py +++ b/dataforge/safety/constitution.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass from functools import lru_cache +from importlib import resources from pathlib import Path from typing import TYPE_CHECKING, Any, Literal @@ -157,7 +158,7 @@ class Constitution: def default_constitution_path() -> Path: """Return the shipped default constitution path.""" - return Path(__file__).resolve().parents[2] / "constitutions" / "default.yaml" + return Path(str(resources.files("dataforge.safety").joinpath("constitutions/default.yaml"))) def _expect_mapping(payload: object, *, message: str) -> dict[str, object]: diff --git a/dataforge/safety/constitutions/default.yaml b/dataforge/safety/constitutions/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75ad9ecadd16519e9e3e485d1e3b0944a4bf0597 --- /dev/null +++ b/dataforge/safety/constitutions/default.yaml @@ -0,0 +1,24 @@ +hard_never: + - id: NO_PII_OVERWRITE + description: Columns tagged as PII cannot be modified unless the operator explicitly allows and confirms it. + predicate: pii_overwrite + override_flag: allow_pii + confirm_flag: confirm_pii + - id: NO_ROW_DELETE + description: Candidate fixes may not delete rows; only cell updates are allowed. + predicate: row_delete + - id: NO_CONFLICTING_CELL_WRITES + description: Multiple accepted fixes may not target the same cell with different replacement values. + predicate: conflicting_cell_writes + scope: batch + +soft_require_confirm: + - id: NO_AGGREGATE_BREAK + description: Candidate fixes to aggregate-sensitive source columns require confirmation. + predicate: aggregate_sensitive + confirm_flag: confirm_escalations + +soft_prefer: + - id: MINIMAL_EDIT + description: Prefer the smallest edit distance when multiple candidates are otherwise viable. + scorer: minimal_edit_distance diff --git a/dataforge/table.py b/dataforge/table.py new file mode 100644 index 0000000000000000000000000000000000000000..aae94a3ee66eb81429f3817f8b77210f7e3cf248 --- /dev/null +++ b/dataforge/table.py @@ -0,0 +1,197 @@ +"""Small string-preserving table primitives for DataForge core paths. + +The CLI hot path should not need pandas just to profile or repair a CSV. +This module provides the narrow DataFrame-like surface that DataForge's +detectors, repairers, and verifier actually need. +""" + +from __future__ import annotations + +import csv +import io +from collections.abc import Iterable, Iterator, Sequence +from pathlib import Path +from typing import Any, Protocol, overload + + +class TableLike(Protocol): + """Protocol for the tabular surface consumed by DataForge core logic.""" + + @property + def columns(self) -> Sequence[str]: ... + + @property + def index(self) -> Sequence[int]: ... + + +class ColumnView(Sequence[str]): + """Read-only column view with the small API repairers expect.""" + + def __init__(self, values: Sequence[str]) -> None: + self._values = values + + def __iter__(self) -> Iterator[str]: + return iter(self._values) + + def __len__(self) -> int: + return len(self._values) + + def __getitem__(self, index: int) -> str: + return self._values[index] + + def tolist(self) -> list[str]: + """Return a list copy, matching pandas Series enough for detectors.""" + return list(self._values) + + +class _AtIndexer: + """``table.at[row, column]`` getter/setter compatibility shim.""" + + def __init__(self, table: Table) -> None: + self._table = table + + def __getitem__(self, key: tuple[int, str]) -> str: + row, column = key + return self._table.cell(row, column) + + def __setitem__(self, key: tuple[int, str], value: object) -> None: + row, column = key + self._table.set_cell(row, column, value) + + +class Table: + """In-memory CSV table with string-preserving cells.""" + + def __init__(self, columns: Sequence[str], rows: Iterable[dict[str, object]]) -> None: + self._columns = [str(column) for column in columns] + self._rows: list[dict[str, str]] = [ + {column: "" if row.get(column) is None else str(row.get(column, "")) for column in self._columns} + for row in rows + ] + self.at = _AtIndexer(self) + + @property + def columns(self) -> list[str]: + """Return column names in CSV order.""" + return list(self._columns) + + @property + def index(self) -> range: + """Return zero-based row positions.""" + return range(len(self._rows)) + + @property + def empty(self) -> bool: + """Return whether the table has no rows.""" + return not self._rows + + @overload + def __getitem__(self, key: str) -> ColumnView: ... + + @overload + def __getitem__(self, key: Sequence[str]) -> Table: ... + + def __getitem__(self, key: str | Sequence[str]) -> ColumnView | Table: + if isinstance(key, str): + if key not in self._columns: + raise KeyError(key) + return ColumnView([row.get(key, "") for row in self._rows]) + columns = [str(column) for column in key] + for column in columns: + if column not in self._columns: + raise KeyError(column) + return Table(columns, ({column: row.get(column, "") for column in columns} for row in self._rows)) + + def __len__(self) -> int: + return len(self._rows) + + def copy(self, deep: bool = True) -> Table: + """Return an independent table copy.""" + del deep + return Table(self._columns, (dict(row) for row in self._rows)) + + def cell(self, row: int, column: str) -> str: + """Return a cell value.""" + if column not in self._columns: + raise KeyError(column) + return self._rows[row].get(column, "") + + def set_cell(self, row: int, column: str, value: object) -> None: + """Set a cell value after validating the column.""" + if column not in self._columns: + raise KeyError(column) + self._rows[row][column] = "" if value is None else str(value) + + def iter_records(self, columns: Sequence[str] | None = None) -> Iterator[dict[str, str]]: + """Yield row dictionaries in table order.""" + selected = self._columns if columns is None else [str(column) for column in columns] + for row in self._rows: + yield {column: row.get(column, "") for column in selected} + + def to_dict(self, orient: str = "records") -> list[dict[str, str]]: + """Return records in the pandas-compatible orientation used by DataForge.""" + if orient != "records": + raise ValueError("Only orient='records' is supported.") + return list(self.iter_records()) + + def to_csv(self, buffer: io.StringIO, *, index: bool = False, lineterminator: str = "\n") -> None: + """Write the table as CSV to a text buffer.""" + if index: + raise ValueError("Table.to_csv does not support index=True.") + writer = csv.DictWriter(buffer, fieldnames=self._columns, lineterminator=lineterminator) + writer.writeheader() + for row in self._rows: + writer.writerow({column: row.get(column, "") for column in self._columns}) + + +def read_csv(path: Path) -> Table: + """Read a CSV as a string-preserving ``Table``.""" + with path.open("r", encoding="utf-8-sig", newline="") as handle: + reader = csv.DictReader(handle) + columns = list(reader.fieldnames or []) + return Table(columns, reader) + + +def table_to_csv_bytes(table: TableLike) -> bytes: + """Serialize a table-like object to UTF-8 CSV bytes.""" + output = io.StringIO() + if isinstance(table, Table): + table.to_csv(output, index=False, lineterminator="\n") + else: + # pandas-compatible fallback for tests and optional integrations. + table.to_csv(output, index=False, lineterminator="\n") # type: ignore[attr-defined] + return output.getvalue().encode("utf-8") + + +def column_names(table: TableLike) -> list[str]: + """Return table column names as strings.""" + return [str(column) for column in table.columns] + + +def row_count(table: TableLike) -> int: + """Return the number of rows in a table-like object.""" + return len(table.index) + + +def column_values(table: TableLike, column: str) -> list[Any]: + """Return all values for one column.""" + values = table[column] # type: ignore[index] + if hasattr(values, "tolist"): + return list(values.tolist()) + return list(values) + + +def cell_value(table: TableLike, row: int, column: str) -> str: + """Return a cell value as a string.""" + return str(table.at[row, column]) # type: ignore[attr-defined,index] + + +def set_cell_value(table: TableLike, row: int, column: str, value: object) -> None: + """Set a cell value on a table-like object.""" + table.at[row, column] = value # type: ignore[attr-defined,index] + + +def copy_table(table: TableLike) -> TableLike: + """Return a deep copy of a table-like object.""" + return table.copy(deep=True) # type: ignore[attr-defined] + diff --git a/dataforge/transactions/__init__.py b/dataforge/transactions/__init__.py index 7fb2fe606371ed6ab9e4049170d022d7fd51f60c..6c7ccdfe0308bc3d539064eb99b9a906a7ab2398 100644 --- a/dataforge/transactions/__init__.py +++ b/dataforge/transactions/__init__.py @@ -1,17 +1,22 @@ """Transaction exports for DataForge.""" from dataforge.transactions.log import ( + TransactionAuditReport, + TransactionAuditVerdict, append_applied_event, append_created_transaction, append_reverted_event, find_transaction_log, load_transaction, + verify_transaction_log, ) from dataforge.transactions.revert import TransactionRevertError, revert_transaction from dataforge.transactions.txn import CellFix, RepairTransaction, generate_txn_id __all__ = [ "CellFix", + "TransactionAuditReport", + "TransactionAuditVerdict", "RepairTransaction", "TransactionRevertError", "append_applied_event", @@ -21,4 +26,5 @@ __all__ = [ "generate_txn_id", "load_transaction", "revert_transaction", + "verify_transaction_log", ] diff --git a/dataforge/transactions/log.py b/dataforge/transactions/log.py index ba39237fbd734990b3ffb0a45f7d0343dbfc9e90..612b09d60c2120d86ec87f6ab5d32998a3aa58ee 100644 --- a/dataforge/transactions/log.py +++ b/dataforge/transactions/log.py @@ -2,21 +2,51 @@ from __future__ import annotations +import enum import hashlib import json +import re from datetime import UTC, datetime from pathlib import Path from typing import Any +from pydantic import BaseModel, Field + from dataforge.transactions.txn import RepairTransaction -SCHEMA_VERSION = 1 +LEGACY_SCHEMA_VERSION = 1 +SCHEMA_VERSION = 2 +_SHA256_RE = re.compile(r"^[0-9a-f]{64}$") class TransactionLogError(Exception): """Raised when a transaction journal cannot be written or replayed.""" +class TransactionAuditVerdict(enum.Enum): + """Possible outcomes for transaction log audit verification.""" + + VERIFIED = "verified" + LEGACY_UNVERIFIED = "legacy_unverified" + TAMPERED = "tampered" + MISSING = "missing" + MALFORMED = "malformed" + + +class TransactionAuditReport(BaseModel): + """Machine-readable result of transaction hash-chain verification.""" + + verdict: TransactionAuditVerdict + log_path: str | None = None + txn_id: str | None = None + schema_version: int | None = None + event_count: int = Field(ge=0) + head_sha256: str | None = Field(default=None, pattern=r"^[0-9a-f]{64}$") + errors: tuple[str, ...] = Field(default_factory=tuple) + + model_config = {"frozen": True} + + def sha256_bytes(payload: bytes) -> str: """Return the SHA-256 digest for the given payload.""" return hashlib.sha256(payload).hexdigest() @@ -62,6 +92,29 @@ def _utc_now() -> datetime: return datetime.now(UTC) +def _canonical_event_bytes(record: dict[str, Any]) -> bytes: + """Serialize an audit event into the canonical hash material.""" + unsigned = {key: value for key, value in record.items() if key != "event_sha256"} + return json.dumps( + unsigned, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + + +def _event_sha256(record: dict[str, Any]) -> str: + """Return the canonical SHA-256 hash for an event record.""" + return sha256_bytes(_canonical_event_bytes(record)) + + +def _sign_event(record: dict[str, Any]) -> dict[str, Any]: + """Return a copy of ``record`` with its canonical event hash attached.""" + signed = dict(record) + signed["event_sha256"] = _event_sha256(signed) + return signed + + def _write_jsonl_line(path: Path, record: dict[str, Any], *, create: bool = False) -> None: """Append or create a JSONL record on disk. @@ -83,6 +136,129 @@ def _write_jsonl_line(path: Path, record: dict[str, Any], *, create: bool = Fals raise TransactionLogError(f"Could not write transaction log '{path}': {exc}") from exc +def _read_records(log_path: Path) -> list[dict[str, Any]]: + """Read non-empty JSONL records from a transaction log.""" + records: list[dict[str, Any]] = [] + for line_number, raw_line in enumerate(log_path.read_text(encoding="utf-8").splitlines(), 1): + if not raw_line.strip(): + continue + try: + payload = json.loads(raw_line) + except json.JSONDecodeError as exc: + raise TransactionLogError( + f"Malformed JSON at {log_path}:{line_number}: {exc.msg}" + ) from exc + if not isinstance(payload, dict): + raise TransactionLogError(f"Malformed transaction event at {log_path}:{line_number}.") + records.append(payload) + return records + + +def _log_schema_version(log_path: Path) -> int | None: + """Return the first event schema version for an existing log.""" + if not log_path.exists(): + return None + records = _read_records(log_path) + if not records: + return None + raw_version = records[0].get("schema_version") + return raw_version if isinstance(raw_version, int) else None + + +def _next_event_metadata(log_path: Path) -> tuple[int, str | None]: + """Return the next v2 event index and previous hash for ``log_path``.""" + records = _read_records(log_path) + if not records: + raise TransactionLogError(f"Transaction log '{log_path}' contained no events.") + previous = records[-1].get("event_sha256") + if not isinstance(previous, str) or not _SHA256_RE.fullmatch(previous): + raise TransactionLogError( + f"Transaction log '{log_path}' is missing a valid previous event hash." + ) + return len(records), previous + + +def _v1_created_record(transaction: RepairTransaction) -> dict[str, Any]: + """Build a legacy v1 transaction creation event.""" + return { + "schema_version": LEGACY_SCHEMA_VERSION, + "event_type": "created", + "occurred_at": transaction.created_at.isoformat(), + "transaction": transaction.model_dump(mode="json"), + } + + +def _v2_created_record(transaction: RepairTransaction) -> dict[str, Any]: + """Build a hash-chained v2 transaction creation event.""" + return _sign_event( + { + "schema_version": SCHEMA_VERSION, + "event_index": 0, + "event_type": "created", + "occurred_at": transaction.created_at.isoformat(), + "previous_event_sha256": None, + "transaction": transaction.model_dump(mode="json"), + } + ) + + +def _v1_applied_record(txn_id: str, post_sha256: str, applied_at: datetime) -> dict[str, Any]: + """Build a legacy v1 applied event.""" + return { + "schema_version": LEGACY_SCHEMA_VERSION, + "event_type": "applied", + "occurred_at": applied_at.isoformat(), + "txn_id": txn_id, + "post_sha256": post_sha256, + } + + +def _v2_applied_record( + log_path: Path, + txn_id: str, + post_sha256: str, + applied_at: datetime, +) -> dict[str, Any]: + """Build a hash-chained v2 applied event.""" + event_index, previous_hash = _next_event_metadata(log_path) + return _sign_event( + { + "schema_version": SCHEMA_VERSION, + "event_index": event_index, + "event_type": "applied", + "occurred_at": applied_at.isoformat(), + "previous_event_sha256": previous_hash, + "txn_id": txn_id, + "post_sha256": post_sha256, + } + ) + + +def _v1_reverted_record(txn_id: str, reverted_at: datetime) -> dict[str, Any]: + """Build a legacy v1 reverted event.""" + return { + "schema_version": LEGACY_SCHEMA_VERSION, + "event_type": "reverted", + "occurred_at": reverted_at.isoformat(), + "txn_id": txn_id, + } + + +def _v2_reverted_record(log_path: Path, txn_id: str, reverted_at: datetime) -> dict[str, Any]: + """Build a hash-chained v2 reverted event.""" + event_index, previous_hash = _next_event_metadata(log_path) + return _sign_event( + { + "schema_version": SCHEMA_VERSION, + "event_index": event_index, + "event_type": "reverted", + "occurred_at": reverted_at.isoformat(), + "previous_event_sha256": previous_hash, + "txn_id": txn_id, + } + ) + + def append_created_transaction(transaction: RepairTransaction) -> Path: """Write the immutable transaction creation event. @@ -94,13 +270,7 @@ def append_created_transaction(transaction: RepairTransaction) -> Path: """ source_path = Path(transaction.source_path) log_path = transaction_log_path_for(source_path, transaction.txn_id) - record = { - "schema_version": SCHEMA_VERSION, - "event_type": "created", - "occurred_at": transaction.created_at.isoformat(), - "transaction": transaction.model_dump(mode="json"), - } - _write_jsonl_line(log_path, record, create=True) + _write_jsonl_line(log_path, _v2_created_record(transaction), create=True) return log_path @@ -112,13 +282,12 @@ def append_applied_event( applied_at: datetime | None = None, ) -> None: """Append an ``applied`` event to an existing transaction log.""" - record = { - "schema_version": SCHEMA_VERSION, - "event_type": "applied", - "occurred_at": (applied_at or _utc_now()).isoformat(), - "txn_id": txn_id, - "post_sha256": post_sha256, - } + occurred_at = applied_at or _utc_now() + record = ( + _v1_applied_record(txn_id, post_sha256, occurred_at) + if _log_schema_version(log_path) == LEGACY_SCHEMA_VERSION + else _v2_applied_record(log_path, txn_id, post_sha256, occurred_at) + ) _write_jsonl_line(log_path, record, create=False) @@ -129,12 +298,12 @@ def append_reverted_event( reverted_at: datetime | None = None, ) -> None: """Append a ``reverted`` event to an existing transaction log.""" - record = { - "schema_version": SCHEMA_VERSION, - "event_type": "reverted", - "occurred_at": (reverted_at or _utc_now()).isoformat(), - "txn_id": txn_id, - } + occurred_at = reverted_at or _utc_now() + record = ( + _v1_reverted_record(txn_id, occurred_at) + if _log_schema_version(log_path) == LEGACY_SCHEMA_VERSION + else _v2_reverted_record(log_path, txn_id, occurred_at) + ) _write_jsonl_line(log_path, record, create=False) @@ -154,11 +323,8 @@ def load_transaction(log_path: Path) -> RepairTransaction: raise TransactionLogError(f"Transaction log not found: {log_path}") transaction: RepairTransaction | None = None - for raw_line in log_path.read_text(encoding="utf-8").splitlines(): - if not raw_line.strip(): - continue - payload = json.loads(raw_line) - if payload.get("schema_version") != SCHEMA_VERSION: + for payload in _read_records(log_path): + if payload.get("schema_version") not in {LEGACY_SCHEMA_VERSION, SCHEMA_VERSION}: raise TransactionLogError( f"Unsupported transaction log schema version in '{log_path}'." ) @@ -230,3 +396,172 @@ def find_transaction_log(txn_id: str, *, search_root: Path | None = None) -> Pat if len(matches) > 1: raise TransactionLogError(f"Found multiple transaction logs for '{txn_id}' under '{root}'.") return matches[0] + + +def verify_transaction_log( + txn_id: str | None = None, + *, + log_path: Path | None = None, + search_root: Path | None = None, +) -> TransactionAuditReport: + """Verify a transaction log's local hash chain. + + Legacy v1 logs remain replayable but cannot be cryptographically verified, + so they return ``legacy_unverified`` instead of ``verified``. + """ + try: + resolved_log_path = log_path.resolve() if log_path is not None else None + if resolved_log_path is None: + if txn_id is None: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.MISSING, + txn_id=txn_id, + event_count=0, + errors=("txn_id or log_path is required.",), + ) + resolved_log_path = find_transaction_log(txn_id, search_root=search_root) + except TransactionLogError as exc: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.MISSING, + txn_id=txn_id, + event_count=0, + errors=(str(exc),), + ) + + if not resolved_log_path.exists(): + return TransactionAuditReport( + verdict=TransactionAuditVerdict.MISSING, + log_path=str(resolved_log_path), + txn_id=txn_id, + event_count=0, + errors=(f"Transaction log not found: {resolved_log_path}",), + ) + + try: + records = _read_records(resolved_log_path) + except TransactionLogError as exc: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.MALFORMED, + log_path=str(resolved_log_path), + txn_id=txn_id, + event_count=0, + errors=(str(exc),), + ) + + if not records: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.MALFORMED, + log_path=str(resolved_log_path), + txn_id=txn_id, + event_count=0, + errors=("Transaction log contained no events.",), + ) + + versions = {record.get("schema_version") for record in records} + if versions == {LEGACY_SCHEMA_VERSION}: + try: + transaction = load_transaction(resolved_log_path) + except TransactionLogError as exc: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.MALFORMED, + log_path=str(resolved_log_path), + schema_version=LEGACY_SCHEMA_VERSION, + event_count=len(records), + errors=(str(exc),), + ) + if txn_id is not None and transaction.txn_id != txn_id: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.TAMPERED, + log_path=str(resolved_log_path), + txn_id=transaction.txn_id, + schema_version=LEGACY_SCHEMA_VERSION, + event_count=len(records), + errors=(f"Expected txn_id '{txn_id}', found '{transaction.txn_id}'.",), + ) + return TransactionAuditReport( + verdict=TransactionAuditVerdict.LEGACY_UNVERIFIED, + log_path=str(resolved_log_path), + txn_id=transaction.txn_id, + schema_version=LEGACY_SCHEMA_VERSION, + event_count=len(records), + errors=("Legacy v1 logs do not contain event hashes.",), + ) + + if versions != {SCHEMA_VERSION}: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.MALFORMED, + log_path=str(resolved_log_path), + txn_id=txn_id, + event_count=len(records), + errors=(f"Mixed or unsupported schema versions: {sorted(map(str, versions))}.",), + ) + + errors: list[str] = [] + previous_hash: str | None = None + resolved_txn_id: str | None = None + head_sha256: str | None = None + for expected_index, record in enumerate(records): + if record.get("event_index") != expected_index: + errors.append(f"Event {expected_index} has event_index {record.get('event_index')!r}.") + if record.get("previous_event_sha256") != previous_hash: + errors.append(f"Event {expected_index} previous hash does not match.") + + recorded_hash = record.get("event_sha256") + if not isinstance(recorded_hash, str) or not _SHA256_RE.fullmatch(recorded_hash): + errors.append(f"Event {expected_index} is missing a valid event hash.") + else: + calculated_hash = _event_sha256(record) + if calculated_hash != recorded_hash: + errors.append(f"Event {expected_index} hash does not match its payload.") + previous_hash = recorded_hash + head_sha256 = recorded_hash + + event_type = record.get("event_type") + if event_type == "created": + raw_transaction = record.get("transaction") + if not isinstance(raw_transaction, dict): + errors.append("Created event is missing a transaction payload.") + else: + current_txn_id = raw_transaction.get("txn_id") + if not isinstance(current_txn_id, str): + errors.append("Created transaction payload is missing txn_id.") + elif resolved_txn_id is None: + resolved_txn_id = current_txn_id + elif resolved_txn_id != current_txn_id: + errors.append("Created transaction payload changed txn_id.") + elif event_type in {"applied", "reverted"}: + current_txn_id = record.get("txn_id") + if current_txn_id != resolved_txn_id: + errors.append( + f"Event {expected_index} txn_id {current_txn_id!r} does not match created event." + ) + else: + errors.append(f"Event {expected_index} has unknown event_type {event_type!r}.") + + if txn_id is not None and resolved_txn_id is not None and resolved_txn_id != txn_id: + errors.append(f"Expected txn_id '{txn_id}', found '{resolved_txn_id}'.") + + try: + load_transaction(resolved_log_path) + except TransactionLogError as exc: + errors.append(str(exc)) + + if errors: + return TransactionAuditReport( + verdict=TransactionAuditVerdict.TAMPERED, + log_path=str(resolved_log_path), + txn_id=resolved_txn_id or txn_id, + schema_version=SCHEMA_VERSION, + event_count=len(records), + head_sha256=head_sha256, + errors=tuple(errors), + ) + + return TransactionAuditReport( + verdict=TransactionAuditVerdict.VERIFIED, + log_path=str(resolved_log_path), + txn_id=resolved_txn_id, + schema_version=SCHEMA_VERSION, + event_count=len(records), + head_sha256=head_sha256, + ) diff --git a/dataforge/transactions/revert.py b/dataforge/transactions/revert.py index 2a46c2a1f31bf7012b0fc771e826013a534b0901..f24a92e806ea41c2007cced824f7dfc59b75187e 100644 --- a/dataforge/transactions/revert.py +++ b/dataforge/transactions/revert.py @@ -5,10 +5,12 @@ from __future__ import annotations from pathlib import Path from dataforge.transactions.log import ( + TransactionAuditVerdict, append_reverted_event, find_transaction_log, load_transaction, sha256_file, + verify_transaction_log, ) from dataforge.transactions.txn import RepairTransaction @@ -31,6 +33,15 @@ def revert_transaction(txn_id: str, *, search_root: Path | None = None) -> Repai TransactionRevertError: If the transaction is not revertible or hash checks fail. """ log_path = find_transaction_log(txn_id, search_root=search_root) + audit_report = verify_transaction_log(txn_id, log_path=log_path) + if audit_report.verdict not in { + TransactionAuditVerdict.VERIFIED, + TransactionAuditVerdict.LEGACY_UNVERIFIED, + }: + details = "; ".join(audit_report.errors) or audit_report.verdict.value + raise TransactionRevertError( + f"Refusing to revert because transaction audit verification failed: {details}" + ) transaction = load_transaction(log_path) if not transaction.applied or transaction.post_sha256 is None: diff --git a/dataforge/ui/repair_diff.py b/dataforge/ui/repair_diff.py index 0b41b17039d8142e830ddf17172598113ae06d7f..eb80f348d89bbdb50e6606393221968a6579decb 100644 --- a/dataforge/ui/repair_diff.py +++ b/dataforge/ui/repair_diff.py @@ -2,15 +2,54 @@ from __future__ import annotations +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, cast + from rich.console import Console from rich.panel import Panel from rich.table import Table -from dataforge.repairers.base import ProposedFix + +@dataclass(frozen=True) +class _RepairDiffRow: + row: int + column: str + old_value: str + new_value: str + detector_id: str + confidence: float + provenance: str + + +def _row_from_fix(fix: object) -> _RepairDiffRow: + """Normalize internal ProposedFix and public VerifiedFix objects.""" + nested = getattr(fix, "fix", None) + if nested is None: + public = cast(Any, fix) + return _RepairDiffRow( + row=public.row, + column=public.column, + old_value=public.old_value, + new_value=public.new_value, + detector_id=public.detector_id, + confidence=public.confidence, + provenance=public.provenance, + ) + proposed = cast(Any, fix) + return _RepairDiffRow( + row=nested.row, + column=nested.column, + old_value=nested.old_value, + new_value=nested.new_value, + detector_id=nested.detector_id, + confidence=proposed.confidence, + provenance=proposed.provenance, + ) def render_repair_diff( - fixes: list[ProposedFix], + fixes: Sequence[object], console: Console | None = None, *, file_path: str = "", @@ -37,15 +76,16 @@ def render_repair_diff( table.add_column("Confidence", justify="right", min_width=10) table.add_column("Provenance", min_width=13) - for proposed in fixes: + for fix in fixes: + row = _row_from_fix(fix) table.add_row( - str(proposed.fix.row), - proposed.fix.column, - proposed.fix.old_value, - proposed.fix.new_value, - proposed.fix.detector_id, - f"{proposed.confidence:.0%}", - proposed.provenance, + str(row.row), + row.column, + row.old_value, + row.new_value, + row.detector_id, + f"{row.confidence:.0%}", + row.provenance, ) target_console.print(table) diff --git a/dataforge/verifier/smt.py b/dataforge/verifier/smt.py index 650d740d55050410003969c039a53edf9eb76a63..46a2fc0e6f3b7b298ff07d0cf7d12f89e56e5215 100644 --- a/dataforge/verifier/smt.py +++ b/dataforge/verifier/smt.py @@ -7,7 +7,6 @@ from collections.abc import Callable from dataclasses import dataclass from typing import Any -import pandas as pd from pydantic import BaseModel, Field from z3 import ( # type: ignore[import-untyped] And, @@ -29,6 +28,14 @@ from z3 import ( # type: ignore[import-untyped] ) from dataforge.repairers.base import ProposedFix +from dataforge.table import ( + TableLike, + cell_value, + column_names, + copy_table, + row_count, + set_cell_value, +) from dataforge.verifier.explain import explain_unsat_core from dataforge.verifier.schema import DomainBound, FunctionalDependency, Schema @@ -67,7 +74,7 @@ class _ColumnEncoding: class SchemaToSMT: """Compile candidate-local constraints from a schema and working dataframe.""" - def __init__(self, schema: Schema, df: pd.DataFrame, *, timeout_ms: int = 200) -> None: + def __init__(self, schema: Schema, df: TableLike, *, timeout_ms: int = 200) -> None: self._schema = schema self._df = df self._timeout_ms = timeout_ms @@ -82,12 +89,12 @@ class SchemaToSMT: row = proposed_fix.fix.row column = proposed_fix.fix.column - if row < 0 or row >= len(self._df.index): + if row < 0 or row >= row_count(self._df): return VerificationResult( verdict=VerificationVerdict.REJECT, reason=f"Row {row} is out of bounds for the input file.", ) - if column not in self._df.columns: + if column not in column_names(self._df): return VerificationResult( verdict=VerificationVerdict.REJECT, reason=f"Column '{column}' does not exist in the input file.", @@ -186,8 +193,8 @@ class SchemaToSMT: proposed_fix: ProposedFix, ) -> None: for column, encoding in encodings.items(): - for index in range(len(self._df.index)): - raw_value = str(self._df.at[index, column]) + for index in range(row_count(self._df)): + raw_value = cell_value(self._df, index, column) if index == proposed_fix.fix.row and column == proposed_fix.fix.column: raw_value = proposed_fix.fix.new_value try: @@ -235,7 +242,7 @@ class SchemaToSMT: ) -> None: # Use a universally-quantified implication over all valid other rows. other_row = Int("other_row") - bounds_guard = And(other_row >= 0, other_row < len(self._df.index)) + bounds_guard = And(other_row >= 0, other_row < row_count(self._df)) candidate_row = IntVal(proposed_fix.fix.row) determinant_equal = And( *[ @@ -259,20 +266,20 @@ class SMTVerifier: def verify( self, - df: pd.DataFrame, + df: TableLike, fixes: list[ProposedFix], schema: Schema | None = None, ) -> VerificationResult: """Verify one or more candidate fixes against the working dataframe.""" if schema is None: - row_count = len(df.index) + total_rows = row_count(df) for proposed in fixes: - if proposed.fix.row < 0 or proposed.fix.row >= row_count: + if proposed.fix.row < 0 or proposed.fix.row >= total_rows: return VerificationResult( verdict=VerificationVerdict.REJECT, reason=f"Row {proposed.fix.row} is out of bounds for the input file.", ) - if proposed.fix.column not in df.columns: + if proposed.fix.column not in column_names(df): return VerificationResult( verdict=VerificationVerdict.REJECT, reason=f"Column '{proposed.fix.column}' does not exist in the input file.", @@ -282,13 +289,13 @@ class SMTVerifier: reason="All proposed fixes passed structural verification.", ) - working_df = df.copy(deep=True) + working_df = copy_table(df) verifier = SchemaToSMT(schema, working_df) for proposed in fixes: result = verifier.verify_fix(proposed) if result.verdict != VerificationVerdict.ACCEPT: return result - working_df.at[proposed.fix.row, proposed.fix.column] = proposed.fix.new_value + set_cell_value(working_df, proposed.fix.row, proposed.fix.column, proposed.fix.new_value) verifier = SchemaToSMT(schema, working_df) return VerificationResult( verdict=VerificationVerdict.ACCEPT, diff --git a/playground/api/app.py b/playground/api/app.py index b97219fc4e7e183733cf84e5c566255b92c5203b..929b28e55fb873ce936e26b89a92d7a35e06c4e8 100644 --- a/playground/api/app.py +++ b/playground/api/app.py @@ -1,146 +1,257 @@ -"""DataForge Playground — stateless FastAPI backend. +"""Stateless FastAPI backend for the hosted DataForge playground. -Provides /api/health, /api/profile, /api/repair, and /api/samples/{name} -endpoints for the hosted playground demo. All processing is ephemeral: -uploaded files are held in memory or per-request TemporaryDirectories and -discarded after the response. +The hosted playground is intentionally split across two free-tier hosts: -See specs/SPEC_playground.md for the full contract. +- Cloudflare Workers Static Assets serves the static frontend. +- Hugging Face Spaces serves this API-only backend. -Note: This module intentionally omits `from __future__ import annotations` -because FastAPI relies on runtime type inspection for UploadFile and other -parameter annotations. PEP 563 deferred evaluation breaks FastAPI's -dependency injection. +All uploaded data is processed in memory or under a per-request temporary +directory and is discarded before the request completes. """ import io import logging import os import tempfile +import time +from collections import defaultdict +from collections.abc import Callable +from importlib import import_module from pathlib import Path -from typing import Any +from typing import Any, Protocol, TypeVar, cast import pandas as pd from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, JSONResponse, StreamingResponse -from fastapi.staticfiles import StaticFiles -from slowapi import Limiter -from slowapi.errors import RateLimitExceeded -from slowapi.util import get_remote_address -from starlette.middleware.base import BaseHTTPMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.responses import Response from starlette.types import ASGIApp -from dataforge.detectors import run_all_detectors -from dataforge.detectors.base import Issue, Severity -from dataforge.repairers import propose_fixes -from dataforge.repairers.base import ProposedFix -from dataforge.safety.filter import SafetyContext, SafetyFilter, SafetyVerdict -from dataforge.transactions.log import sha256_bytes -from dataforge.transactions.txn import generate_txn_id +from dataforge import ( + CONTRACT_VERSION, + Issue, + RepairPipelineRequest, + RepairTransaction, + Severity, + VerifiedFix, + run_all_detectors, + run_repair_pipeline, +) +from dataforge.http.problem import problem_exception_handler, problem_response +from dataforge.observability import configure_fastapi_observability + + +class FallbackRateLimitExceededError(Exception): + """Fallback exception shape matching slowapi's detail attribute.""" + + def __init__(self, detail: str) -> None: + super().__init__(detail) + self.detail = detail + + +try: + _slowapi_module = import_module("slowapi") + _slowapi_errors = import_module("slowapi.errors") + _slowapi_util = import_module("slowapi.util") + _SlowapiLimiter: Any | None = _slowapi_module.Limiter + _SlowapiRateLimitExceeded: type[Exception] | None = _slowapi_errors.RateLimitExceeded + get_remote_address = cast(Callable[[Request], str], _slowapi_util.get_remote_address) + + SLOWAPI_AVAILABLE = True +except ModuleNotFoundError: + _SlowapiLimiter = None + _SlowapiRateLimitExceeded = None + SLOWAPI_AVAILABLE = False + + def get_remote_address(request: Request) -> str: + """Return the client host for fallback rate-limit keys.""" + return request.client.host if request.client else "unknown" + + +_CallableT = TypeVar("_CallableT", bound=Callable[..., Any]) + + +class _StorageLike(Protocol): + """Minimal storage protocol used by tests and fallback middleware.""" + + def reset(self) -> None: ... + + +class _LimiterLike(Protocol): + """Minimal limiter protocol shared by slowapi and the fallback.""" + + _storage: _StorageLike + + def limit(self, limit_value: str) -> Callable[[_CallableT], _CallableT]: ... + + +class _FallbackStorage: + """Small in-memory windowed counter used when slowapi is unavailable.""" + + def __init__(self) -> None: + self._hits: dict[tuple[str, str], list[float]] = defaultdict(list) + + def reset(self) -> None: + """Clear all fallback counters.""" + self._hits.clear() + + def allow(self, key: tuple[str, str], *, limit: int, window_seconds: float) -> bool: + """Record a hit and return whether it fits inside the window.""" + now = time.monotonic() + hits = [seen for seen in self._hits[key] if now - seen < window_seconds] + hits.append(now) + self._hits[key] = hits + return len(hits) <= limit + + +class _FallbackLimiter: + """Decorator-compatible fallback limiter.""" + + def __init__(self) -> None: + self._storage: _StorageLike = _FallbackStorage() + + def limit(self, _limit_value: str) -> Callable[[_CallableT], _CallableT]: + """Return an identity decorator; middleware enforces the limit.""" + + def decorator(func: _CallableT) -> _CallableT: + return func + + return decorator + + +_RateLimitExceeded: type[Exception] = ( + _SlowapiRateLimitExceeded + if _SlowapiRateLimitExceeded is not None + else FallbackRateLimitExceededError +) logger = logging.getLogger("playground.api") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -MAX_UPLOAD_BYTES = 1_048_576 # 1 MB +MAX_UPLOAD_BYTES = 1_048_576 +MAX_MULTIPART_OVERHEAD_BYTES = 16_384 SAMPLES_DIR = Path(__file__).resolve().parent / "samples" -WEB_DIR = Path(__file__).resolve().parent / "web" +SLOWAPI_CONFIG = Path(__file__).resolve().parent / "slowapi.env" ALLOWED_SAMPLES = {"hospital_10rows", "flights_10rows", "beers_10rows"} -# --------------------------------------------------------------------------- -# Size-cap middleware -# --------------------------------------------------------------------------- - class SizeCapMiddleware(BaseHTTPMiddleware): - """Reject requests with Content-Length exceeding the upload cap. - - This prevents OOM on the free-tier Space by checking the declared - Content-Length header before the body is read. The endpoint handlers - additionally enforce a defensive read cap. - """ - - def __init__(self, app: ASGIApp, max_bytes: int = MAX_UPLOAD_BYTES) -> None: + """Reject requests whose declared Content-Length cannot contain a valid upload.""" + + def __init__( + self, + app: ASGIApp, + max_file_bytes: int = MAX_UPLOAD_BYTES, + max_multipart_overhead_bytes: int = MAX_MULTIPART_OVERHEAD_BYTES, + ) -> None: super().__init__(app) - self.max_bytes = max_bytes - - async def dispatch(self, request: Request, call_next: Any) -> Any: - """Check Content-Length and reject oversized requests.""" + self.max_file_bytes = max_file_bytes + self.max_body_bytes = max_file_bytes + max_multipart_overhead_bytes + + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + """Check Content-Length before any request body is read.""" content_length = request.headers.get("content-length") if content_length is not None: try: length = int(content_length) except ValueError: - return JSONResponse( - status_code=400, - content={"error": "invalid_content_length"}, + return JSONResponse(status_code=400, content={"error": "invalid_content_length"}) + if length > self.max_body_bytes: + logger.warning( + "Rejected request: Content-Length %d exceeds max body %d", + length, + self.max_body_bytes, ) - if length > self.max_bytes: - logger.warning("Rejected request: Content-Length %d > %d", length, self.max_bytes) - return JSONResponse( - status_code=413, - content={"error": "file_too_large", "max_bytes": self.max_bytes}, + return problem_response( + status=413, + type_="https://dataforge.local/problems/file_too_large", + title="File Too Large", + detail="The uploaded request body exceeds the playground limit.", + instance=str(request.url.path), + error="file_too_large", + max_bytes=self.max_file_bytes, ) return await call_next(request) -# --------------------------------------------------------------------------- -# Rate limiter (single-worker contract — see SPEC_playground.md §4) -# --------------------------------------------------------------------------- - -limiter = Limiter(key_func=get_remote_address) +class FallbackRateLimitMiddleware(BaseHTTPMiddleware): + """Enforce the playground POST limit when slowapi is not installed.""" + + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + """Apply a 10/minute in-memory fallback to mutating playground endpoints.""" + if request.method == "POST" and request.url.path in {"/api/profile", "/api/repair"}: + storage = limiter._storage + key = (get_remote_address(request), request.url.path) + if isinstance(storage, _FallbackStorage) and not storage.allow( + key, + limit=10, + window_seconds=60.0, + ): + return problem_response( + status=429, + type_="https://dataforge.local/problems/rate_limit_exceeded", + title="Rate Limit Exceeded", + detail="10 per 1 minute", + instance=str(request.url.path), + headers={"Retry-After": "60"}, + error="rate_limit_exceeded", + ) + return await call_next(request) -# --------------------------------------------------------------------------- -# CORS configuration -# --------------------------------------------------------------------------- +if _SlowapiLimiter is not None: + limiter: _LimiterLike = cast( + _LimiterLike, + _SlowapiLimiter(key_func=get_remote_address, config_filename=str(SLOWAPI_CONFIG)), + ) +else: + limiter = _FallbackLimiter() -def _build_cors_origins() -> list[str]: - """Build the CORS allow-origins list from environment.""" - origins: list[str] = [] - # Explicit origins from env - env_origins = os.environ.get("DATAFORGE_PLAYGROUND_ORIGINS", "") - if env_origins: - origins.extend(o.strip() for o in env_origins.split(",") if o.strip()) +def _advanced_available() -> bool: + """Return whether at least one backend LLM provider is configured.""" + return bool(os.environ.get("GROQ_API_KEY") or os.environ.get("GEMINI_API_KEY")) - # Localhost in dev mode - if os.environ.get("DATAFORGE_PLAYGROUND_DEV", "") == "1": - origins.append("http://localhost:3000") - origins.append("http://localhost:5500") - origins.append("http://localhost:8000") - origins.append("http://localhost:8080") - origins.append("http://127.0.0.1:3000") - origins.append("http://127.0.0.1:5500") - origins.append("http://127.0.0.1:8000") - origins.append("http://127.0.0.1:8080") - return origins +def _build_cors_origins() -> list[str]: + """Build the explicit CORS allowlist from the environment.""" + env_origins = os.environ.get("DATAFORGE_PLAYGROUND_ORIGINS", "") + return [origin.strip() for origin in env_origins.split(",") if origin.strip()] def _build_cors_origin_regex() -> str | None: - """Build regex for Cloudflare Pages preview and production URLs.""" - return r"https://.*\.pages\.dev" - + """Build the regex allowlist for local development only.""" + patterns: list[str] = [] + if os.environ.get("DATAFORGE_PLAYGROUND_DEV") == "1": + patterns.append(r"http://(?:localhost|127(?:\.\d{1,3}){3})(?::\d+)?") + if not patterns: + return None + return "^(" + "|".join(patterns) + ")$" -# --------------------------------------------------------------------------- -# App factory -# --------------------------------------------------------------------------- app = FastAPI( - title="DataForge Playground", - description="Stateless demo of DataForge profile and repair capabilities.", + title="DataForge Playground API", + description="Stateless backend for the hosted DataForge playground.", version="0.1.0", docs_url="/api/docs", redoc_url=None, ) - -# Middleware order matters: size cap first, then CORS -app.add_middleware(SizeCapMiddleware, max_bytes=MAX_UPLOAD_BYTES) +app.add_middleware( + SizeCapMiddleware, + max_file_bytes=MAX_UPLOAD_BYTES, + max_multipart_overhead_bytes=MAX_MULTIPART_OVERHEAD_BYTES, +) +if not SLOWAPI_AVAILABLE: + app.add_middleware(FallbackRateLimitMiddleware) app.add_middleware( CORSMiddleware, allow_origins=_build_cors_origins(), @@ -149,35 +260,28 @@ app.add_middleware( allow_headers=["*"], allow_credentials=False, ) - app.state.limiter = limiter - -# Mount static frontend if the web directory exists (co-located in HF Space) -if WEB_DIR.exists(): - app.mount("/static", StaticFiles(directory=str(WEB_DIR)), name="static") - - -@app.exception_handler(RateLimitExceeded) -async def _rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: - """Return 429 with a machine-readable error body.""" - return JSONResponse( - status_code=429, - content={"error": "rate_limit_exceeded", "detail": str(exc.detail)}, +app.add_exception_handler(HTTPException, problem_exception_handler) +configure_fastapi_observability(app, service_name="dataforge-playground-api") + + +@app.exception_handler(_RateLimitExceeded) +async def _rate_limit_handler(request: Request, exc: Exception) -> JSONResponse: + """Return a machine-readable 429 response.""" + detail = str(getattr(exc, "detail", str(exc))) + return problem_response( + status=429, + type_="https://dataforge.local/problems/rate_limit_exceeded", + title="Rate Limit Exceeded", + detail=detail, + instance=str(request.url.path), + headers={"Retry-After": "60"}, + error="rate_limit_exceeded", ) -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - async def _read_upload(file: UploadFile) -> bytes: - """Read an uploaded file with a defensive size cap. - - Reads at most MAX_UPLOAD_BYTES + 1 byte. If the read exceeds the cap, - raises HTTPException(413) even if Content-Length was not set or was - spoofed. - """ + """Read an uploaded file with a defensive hard cap.""" data = await file.read(MAX_UPLOAD_BYTES + 1) if len(data) > MAX_UPLOAD_BYTES: raise HTTPException( @@ -188,7 +292,7 @@ async def _read_upload(file: UploadFile) -> bytes: def _csv_to_df(data: bytes) -> pd.DataFrame: - """Parse CSV bytes into a DataFrame using safe defaults.""" + """Parse CSV bytes into a string-preserving DataFrame.""" return pd.read_csv( io.BytesIO(data), dtype=str, @@ -198,92 +302,149 @@ def _csv_to_df(data: bytes) -> pd.DataFrame: def _severity_to_str(severity: Severity) -> str: - """Convert Severity enum to lowercase string.""" + """Convert a Severity enum into the JSON response value.""" return severity.value -def _issues_to_response(issues: list[Issue], df: pd.DataFrame) -> dict[str, Any]: - """Format detected issues into the playground JSON response shape.""" - # Group row indices by (column, issue_type, severity) for compact display +def _issues_to_response( + issues: list[Issue], + df: pd.DataFrame, + *, + advanced_requested: bool, +) -> dict[str, Any]: + """Format detected issues into the public playground JSON contract.""" grouped: dict[tuple[str, str, str], list[int]] = {} for issue in issues: key = (issue.column, issue.issue_type, _severity_to_str(issue.severity)) grouped.setdefault(key, []).append(issue.row) - issue_list = [] + payload_issues: list[dict[str, Any]] = [] for (column, issue_type, severity), row_indices in grouped.items(): - issue_list.append({ - "column": column, - "issue_type": issue_type, - "severity": severity, - "row_indices": sorted(set(row_indices)), - "count": len(set(row_indices)), - }) + unique_rows = sorted(set(row_indices)) + payload_issues.append( + { + "column": column, + "issue_type": issue_type, + "severity": severity, + "row_indices": unique_rows, + "count": len(unique_rows), + } + ) return { - "issues": issue_list, + "issues": payload_issues, "meta": { "rows": len(df), "columns": len(df.columns), "column_names": list(df.columns), "total_issues": len(issues), + "advanced_requested": advanced_requested, + "api_version": app.version, + "contract_version": CONTRACT_VERSION, }, } def _fixes_to_response( - fixes: list[ProposedFix], - txn_id: str, - source_sha256: str, + fixes: list[VerifiedFix], + transaction: RepairTransaction, + *, + source_name: str, ) -> dict[str, Any]: - """Format proposed fixes into the playground JSON response shape.""" - fix_list = [] - for pf in fixes: - fix_list.append({ - "row": pf.fix.row, - "column": pf.fix.column, - "old_value": pf.fix.old_value, - "new_value": pf.fix.new_value, - "detector_id": pf.fix.detector_id, - "reason": pf.reason, - "confidence": pf.confidence, - "provenance": pf.provenance, - }) + """Format accepted repair proposals plus a redacted transaction journal.""" + payload_fixes: list[dict[str, Any]] = [] + for proposed_fix in fixes: + payload_fixes.append( + { + "row": proposed_fix.row, + "column": proposed_fix.column, + "old_value": proposed_fix.old_value, + "new_value": proposed_fix.new_value, + "detector_id": proposed_fix.detector_id, + "reason": proposed_fix.reason, + "confidence": proposed_fix.confidence, + "provenance": proposed_fix.provenance, + } + ) return { - "fixes": fix_list, + "fixes": payload_fixes, "txn_journal": { - "txn_id": txn_id, - "source_sha256": source_sha256, - "fixes_count": len(fix_list), - "applied": False, + "txn_id": transaction.txn_id, + "created_at": transaction.created_at.isoformat(), + "source_name": source_name, + "source_sha256": transaction.source_sha256, + "fixes_count": len(transaction.fixes), + "applied": transaction.applied, + "events": [{"event_type": "created"}], "note": ( - "Playground is stateless. This transaction journal is ephemeral " - "and will not be persisted. Install the CLI to use " - "`dataforge repair --apply` and `dataforge revert`." + "Playground is stateless. This journal is ephemeral and discarded " + "after the response. Install the CLI to apply and revert repairs." ), }, + "meta": { + "api_version": app.version, + "contract_version": CONTRACT_VERSION, + }, } -# --------------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------------- +def _require_advanced_mode(advanced_requested: bool) -> None: + """Reject advanced mode requests unless a provider key is configured.""" + if advanced_requested and not _advanced_available(): + raise HTTPException(status_code=400, detail={"error": "advanced_mode_unavailable"}) + + +def _run_repair_pipeline( + *, + upload_name: str, + source_bytes: bytes, + allow_llm: bool, +) -> tuple[list[VerifiedFix], RepairTransaction]: + """Run the real dry-run repair pipeline inside a temporary workspace.""" + with tempfile.TemporaryDirectory() as tmpdir: + temp_root = Path(tmpdir) + upload_path = temp_root / upload_name + upload_path.write_bytes(source_bytes) + + result = run_repair_pipeline( + RepairPipelineRequest( + source_path=upload_path, + mode="dry_run", + schema=None, + create_dry_run_transaction=True, + allow_llm=allow_llm, + ) + ) + if result.transaction is None: + raise RuntimeError(result.receipt.reason) + return result.fixes, result.transaction + + +@app.get("/") +async def root() -> dict[str, Any]: + """Return service metadata for humans and uptime probes.""" + return { + "service": "DataForge Playground API", + "status": "ok", + "docs_url": "/api/docs", + "frontend_hosting": "cloudflare_static_assets", + } @app.get("/api/health") -async def health() -> dict[str, str]: - """Health check for cold-start warming and uptime monitoring.""" - return {"status": "ok"} +async def health() -> dict[str, Any]: + """Return backend readiness plus UI-facing capability metadata.""" + return { + "status": "ok", + "advanced_available": _advanced_available(), + "max_upload_bytes": MAX_UPLOAD_BYTES, + } @app.get("/api/samples/{name}") async def get_sample(name: str) -> StreamingResponse: - """Serve a bundled sample CSV by name. - - Args: - name: Sample identifier (e.g. 'hospital_10rows'). - """ + """Return a bundled sample CSV by name.""" if name not in ALLOWED_SAMPLES: raise HTTPException( status_code=404, @@ -295,126 +456,83 @@ async def get_sample(name: str) -> StreamingResponse: logger.error("Sample file missing on disk: %s", csv_path) raise HTTPException(status_code=500, detail={"error": "sample_file_missing"}) - csv_bytes = csv_path.read_bytes() return StreamingResponse( - io.BytesIO(csv_bytes), + io.BytesIO(csv_path.read_bytes()), media_type="text/csv", - headers={ - "Content-Disposition": f'attachment; filename="{name}.csv"', - }, + headers={"Content-Disposition": f'attachment; filename="{name}.csv"'}, ) @app.post("/api/profile") @limiter.limit("10/minute") async def profile(request: Request, file: UploadFile) -> dict[str, Any]: - """Profile an uploaded CSV and return detected data-quality issues. - - Heuristic-only by default. Pass `advanced=true` as a query parameter - to use LLM-backed detection (requires a provider key in Space Secrets). - """ - advanced = request.query_params.get("advanced", "false").lower() == "true" - if advanced: - has_key = bool( - os.environ.get("GROQ_API_KEY") or os.environ.get("GEMINI_API_KEY") - ) - if not has_key: - raise HTTPException( - status_code=400, - detail={"error": "advanced_mode_unavailable"}, - ) - - data = await _read_upload(file) - logger.info("Profile request: %d bytes, filename=%s", len(data), file.filename) + """Profile an uploaded CSV and return the detected issues.""" + advanced_requested = request.query_params.get("advanced", "false").lower() == "true" + _require_advanced_mode(advanced_requested) + + source_bytes = await _read_upload(file) + upload_name = Path(file.filename or "upload.csv").name + logger.info( + "Profile request: filename=%s bytes=%d advanced=%s", + upload_name, + len(source_bytes), + advanced_requested, + ) - df = _csv_to_df(data) - issues = run_all_detectors(df, schema=None) + try: + df = _csv_to_df(source_bytes) + issues = run_all_detectors(df, schema=None) + except HTTPException: + raise + except Exception as exc: + logger.exception("Profile endpoint failed") + raise HTTPException( + status_code=500, + detail={ + "error": "profile_failed", + "message": "The profile pipeline could not complete safely.", + }, + ) from exc - return _issues_to_response(issues, df) + return _issues_to_response(issues, df, advanced_requested=advanced_requested) @app.post("/api/repair") @limiter.limit("10/minute") async def repair(request: Request, file: UploadFile) -> dict[str, Any]: - """Propose repairs for an uploaded CSV (dry-run only in playground). - - The repair pipeline runs Safety -> Verifier -> Transaction in a - per-request TemporaryDirectory. The transaction journal is returned - inline and discarded after the response. - """ + """Return dry-run repair proposals plus an ephemeral transaction journal.""" dry_run = request.query_params.get("dry_run", "true").lower() == "true" - advanced = request.query_params.get("advanced", "false").lower() == "true" - - if advanced: - has_key = bool( - os.environ.get("GROQ_API_KEY") or os.environ.get("GEMINI_API_KEY") - ) - if not has_key: - raise HTTPException( - status_code=400, - detail={"error": "advanced_mode_unavailable"}, - ) - - data = await _read_upload(file) - logger.info("Repair request: %d bytes, filename=%s, dry_run=%s", len(data), file.filename, dry_run) - - df = _csv_to_df(data) - source_sha256 = sha256_bytes(data) - - # Detect issues - issues = run_all_detectors(df, schema=None) + advanced_requested = request.query_params.get("advanced", "false").lower() == "true" + + if not dry_run: + raise HTTPException(status_code=400, detail={"error": "apply_not_supported"}) + _require_advanced_mode(advanced_requested) + + source_bytes = await _read_upload(file) + upload_name = Path(file.filename or "upload.csv").name + logger.info( + "Repair request: filename=%s bytes=%d advanced=%s", + upload_name, + len(source_bytes), + advanced_requested, + ) - # Propose fixes (heuristic-only, no LLM unless advanced + keyed) try: - with tempfile.TemporaryDirectory() as tmpdir: - cache_dir = Path(tmpdir) / "cache" - cache_dir.mkdir() - - fixes = propose_fixes( - issues, - df, - schema=None, - cache_dir=cache_dir, - allow_llm=False, - model="gemini-2.0-flash", - ) - - # Run safety filter on each proposed fix - try: - safety = SafetyFilter() - context = SafetyContext() - accepted_fixes: list[ProposedFix] = [] - for fix in fixes: - result = safety.evaluate(fix, schema=None, context=context) - if result.verdict == SafetyVerdict.ALLOW: - accepted_fixes.append(fix) - except Exception: - # Constitution file may not be at the expected path in Docker; - # gracefully skip safety filter and return all fixes. - logger.warning("SafetyFilter init failed; returning unfiltered fixes", exc_info=True) - accepted_fixes = list(fixes) - - # Generate ephemeral transaction ID - txn_id = generate_txn_id() - - return _fixes_to_response(accepted_fixes, txn_id, source_sha256) - except Exception: + fixes, transaction = _run_repair_pipeline( + upload_name=upload_name, + source_bytes=source_bytes, + allow_llm=advanced_requested, + ) + except HTTPException: + raise + except Exception as exc: logger.exception("Repair endpoint failed") raise HTTPException( status_code=500, - detail={"error": "repair_failed", "message": "An internal error occurred during repair."}, - ) - + detail={ + "error": "repair_failed", + "message": "The repair pipeline could not complete safely.", + }, + ) from exc -# --------------------------------------------------------------------------- -# Frontend catch-all (must be LAST route — serves index.html for the SPA) -# --------------------------------------------------------------------------- - - -@app.get("/") -async def root() -> FileResponse: - """Serve the frontend SPA from the co-located web directory.""" - index = WEB_DIR / "index.html" - if index.exists(): - return FileResponse(str(index), media_type="text/html") - return FileResponse(str(Path(__file__).parent / "index_fallback.html"), media_type="text/html") + return _fixes_to_response(fixes, transaction, source_name=upload_name) diff --git a/playground/api/requirements.txt b/playground/api/requirements.txt index 3ff25cc7285eab350a42fc0d0bc9e654c8af47a2..21dd78821fe4a2a4e3042c01e3a0ba3de47990df 100644 --- a/playground/api/requirements.txt +++ b/playground/api/requirements.txt @@ -1,7 +1,9 @@ # Playground API dependencies — isolated from core pyproject.toml runtime deps. # These are installed ONLY inside the HF Docker Space and in CI for smoke tests. # See ARCHITECTURE.md §3 for justification of each dependency. -fastapi==0.115.12 -uvicorn[standard]==0.34.2 +fastapi==0.136.1 +starlette==0.49.3 +uvicorn[standard]==0.35.0 slowapi==0.1.9 -python-multipart==0.0.20 +python-multipart==0.0.27 +pandas==2.3.3 diff --git a/playground/web/app.js b/playground/web/app.js deleted file mode 100644 index aed7b95f580660e3125ade711986740bb4a37b80..0000000000000000000000000000000000000000 --- a/playground/web/app.js +++ /dev/null @@ -1,379 +0,0 @@ -/** - * DataForge Playground — vanilla ES module frontend. - * - * Zero localStorage / sessionStorage (rule 0.4.4). - * Zero external JS dependencies beyond Pico.css (CDN, CSS-only). - * - * BACKEND_URL is replaced at build time via sed for Cloudflare Pages deploy. - * For local dev, change it to http://localhost:7860. - */ - -// --------------------------------------------------------------------------- -// Configuration -// --------------------------------------------------------------------------- - -// BACKEND_URL: empty string = same-origin (frontend served from HF Space). -// For separate Cloudflare Pages deploy, replace with the HF Space URL via sed: -// sed -i "s|const BACKEND_URL = \"\"|const BACKEND_URL = \"https://...\"|g" app.js -const BACKEND_URL = ""; - -const HEALTH_TIMEOUT_MS = 3000; -const HEALTH_MAX_RETRIES = 8; -const HEALTH_INITIAL_DELAY_MS = 1000; - -// --------------------------------------------------------------------------- -// DOM references -// --------------------------------------------------------------------------- - -const statusBanner = document.getElementById("status-banner"); -const statusText = document.getElementById("status-text"); -const warmupProgress = document.getElementById("warmup-progress"); - -const csvUpload = document.getElementById("csv-upload"); -const sampleSelect = document.getElementById("sample-select"); -const advancedToggle = document.getElementById("advanced-toggle"); -const profileBtn = document.getElementById("profile-btn"); -const repairBtn = document.getElementById("repair-btn"); - -const resultsSection = document.getElementById("results-section"); -const tabProfile = document.getElementById("tab-profile"); -const tabRepair = document.getElementById("tab-repair"); -const tabRevert = document.getElementById("tab-revert"); - -const panelProfile = document.getElementById("panel-profile"); -const panelRepair = document.getElementById("panel-repair"); -const panelRevert = document.getElementById("panel-revert"); - -const profileLoading = document.getElementById("profile-loading"); -const profileResults = document.getElementById("profile-results"); -const repairLoading = document.getElementById("repair-loading"); -const repairResults = document.getElementById("repair-results"); -const revertJournal = document.getElementById("revert-journal"); - -// --------------------------------------------------------------------------- -// State (in-memory only — no browser storage) -// --------------------------------------------------------------------------- - -let currentFile = null; -let backendReady = false; - -// --------------------------------------------------------------------------- -// Cold-start health check with exponential backoff -// --------------------------------------------------------------------------- - -async function checkHealth() { - showBanner("Warming up the backend...", true); - let delay = HEALTH_INITIAL_DELAY_MS; - - for (let attempt = 0; attempt < HEALTH_MAX_RETRIES; attempt++) { - try { - const controller = new AbortController(); - const timeout = setTimeout(() => controller.abort(), HEALTH_TIMEOUT_MS); - - const response = await fetch(`${BACKEND_URL}/api/health`, { - signal: controller.signal, - }); - clearTimeout(timeout); - - if (response.ok) { - backendReady = true; - hideBanner(); - enableControls(); - return; - } - } catch { - // Network error or timeout — retry - } - - statusText.textContent = `Warming up the backend... (attempt ${attempt + 2}/${HEALTH_MAX_RETRIES})`; - await sleep(delay); - delay = Math.min(delay * 2, 8000); - } - - showBanner("Backend is unavailable. The Space may be sleeping — try refreshing in 30 seconds.", false); - warmupProgress.style.display = "none"; -} - -// --------------------------------------------------------------------------- -// Tab switching -// --------------------------------------------------------------------------- - -const tabs = [tabProfile, tabRepair, tabRevert]; -const panels = [panelProfile, panelRepair, panelRevert]; - -tabs.forEach((tab, index) => { - tab.addEventListener("click", () => { - tabs.forEach((t) => { - t.setAttribute("aria-selected", "false"); - t.classList.remove("tab-active"); - }); - panels.forEach((p) => p.classList.add("tab-panel--hidden")); - - tab.setAttribute("aria-selected", "true"); - tab.classList.add("tab-active"); - panels[index].classList.remove("tab-panel--hidden"); - }); -}); - -// --------------------------------------------------------------------------- -// File handling -// --------------------------------------------------------------------------- - -csvUpload.addEventListener("change", () => { - const file = csvUpload.files[0]; - if (file) { - currentFile = file; - sampleSelect.value = ""; - updateButtons(); - } -}); - -sampleSelect.addEventListener("change", async () => { - const sampleName = sampleSelect.value; - if (!sampleName) return; - - try { - const response = await fetch(`${BACKEND_URL}/api/samples/${sampleName}`); - if (!response.ok) throw new Error(`Failed to fetch sample: ${response.status}`); - const blob = await response.blob(); - currentFile = new File([blob], `${sampleName}.csv`, { type: "text/csv" }); - csvUpload.value = ""; - updateButtons(); - } catch (err) { - showError(profileResults, `Failed to load sample: ${err.message}`); - } -}); - -function updateButtons() { - const hasFile = currentFile !== null; - profileBtn.disabled = !hasFile || !backendReady; - repairBtn.disabled = !hasFile || !backendReady; -} - -function enableControls() { - updateButtons(); -} - -// --------------------------------------------------------------------------- -// Profile -// --------------------------------------------------------------------------- - -profileBtn.addEventListener("click", async () => { - if (!currentFile || !backendReady) return; - - showTab(0); - resultsSection.classList.remove("results-hidden"); - showLoading(profileLoading, true); - profileResults.innerHTML = ""; - - const formData = new FormData(); - formData.append("file", currentFile); - - const params = new URLSearchParams(); - if (advancedToggle.checked) params.set("advanced", "true"); - - try { - const url = `${BACKEND_URL}/api/profile${params.toString() ? "?" + params : ""}`; - const response = await fetch(url, { method: "POST", body: formData }); - - if (response.status === 413) { - showError(profileResults, "File too large. Maximum upload size is 1 MB."); - return; - } - if (response.status === 400) { - const data = await response.json(); - if (data.detail?.error === "advanced_mode_unavailable") { - showError(profileResults, "Advanced mode is unavailable. No LLM provider key is configured on the backend."); - advancedToggle.checked = false; - advancedToggle.disabled = true; - advancedToggle.title = "Advanced mode requires an LLM provider key configured in the backend."; - return; - } - } - if (!response.ok) { - throw new Error(`Server returned ${response.status}`); - } - - const data = await response.json(); - renderProfileResults(data); - } catch (err) { - showError(profileResults, `Profile failed: ${err.message}`); - } finally { - showLoading(profileLoading, false); - } -}); - -function renderProfileResults(data) { - const { issues, meta } = data; - - let html = `
- ${meta.rows} rows · ${meta.columns} columns · - ${meta.total_issues} issue${meta.total_issues !== 1 ? "s" : ""} detected -
`; - - if (issues.length === 0) { - html += `

No data-quality issues detected.

`; - } else { - html += `
- - - - - - - - - - `; - - for (const issue of issues) { - const badge = severityBadge(issue.severity); - const rows = issue.row_indices.length <= 5 - ? issue.row_indices.join(", ") - : issue.row_indices.slice(0, 5).join(", ") + "..."; - html += ` - - - - - - `; - } - - html += `
ColumnIssue TypeSeverityRows AffectedCount
${escapeHtml(issue.column)}${escapeHtml(issue.issue_type)}${badge}${rows}${issue.count}
`; - } - - profileResults.innerHTML = html; -} - -// --------------------------------------------------------------------------- -// Repair -// --------------------------------------------------------------------------- - -repairBtn.addEventListener("click", async () => { - if (!currentFile || !backendReady) return; - - showTab(1); - resultsSection.classList.remove("results-hidden"); - showLoading(repairLoading, true); - repairResults.innerHTML = ""; - - const formData = new FormData(); - formData.append("file", currentFile); - - const params = new URLSearchParams({ dry_run: "true" }); - if (advancedToggle.checked) params.set("advanced", "true"); - - try { - const response = await fetch(`${BACKEND_URL}/api/repair?${params}`, { - method: "POST", - body: formData, - }); - - if (response.status === 413) { - showError(repairResults, "File too large. Maximum upload size is 1 MB."); - return; - } - if (!response.ok) throw new Error(`Server returned ${response.status}`); - - const data = await response.json(); - renderRepairResults(data); - renderRevertJournal(data.txn_journal); - } catch (err) { - showError(repairResults, `Repair failed: ${err.message}`); - } finally { - showLoading(repairLoading, false); - } -}); - -function renderRepairResults(data) { - const { fixes } = data; - - if (fixes.length === 0) { - repairResults.innerHTML = `

No repairs proposed. The data looks clean.

`; - return; - } - - let html = `
- ${fixes.length} repair${fixes.length !== 1 ? "s" : ""} proposed (dry run) -
`; - - // Unified diff view - html += `
`;
-    for (const fix of fixes) {
-        html += `--- Row ${fix.row}, Column: ${escapeHtml(fix.column)} (${escapeHtml(fix.detector_id)})\n`;
-        html += `- ${escapeHtml(fix.old_value)}\n`;
-        html += `+ ${escapeHtml(fix.new_value)}\n`;
-        html += `  ${escapeHtml(fix.reason)}\n\n`;
-    }
-    html += `
`; - - repairResults.innerHTML = html; -} - -function renderRevertJournal(journal) { - if (!journal) { - revertJournal.textContent = "No transaction journal available."; - return; - } - revertJournal.textContent = JSON.stringify(journal, null, 2); - revertJournal.classList.remove("journal-empty"); -} - -// --------------------------------------------------------------------------- -// Utilities -// --------------------------------------------------------------------------- - -function showTab(index) { - tabs.forEach((t, i) => { - t.setAttribute("aria-selected", i === index ? "true" : "false"); - t.classList.toggle("tab-active", i === index); - }); - panels.forEach((p, i) => { - p.classList.toggle("tab-panel--hidden", i !== index); - }); -} - -function showBanner(message, showProgress) { - statusText.textContent = message; - warmupProgress.style.display = showProgress ? "block" : "none"; - statusBanner.classList.remove("banner--hidden"); - statusBanner.classList.add("banner--visible"); -} - -function hideBanner() { - statusBanner.classList.add("banner--hidden"); - statusBanner.classList.remove("banner--visible"); -} - -function showLoading(el, show) { - el.classList.toggle("loading-hidden", !show); - el.classList.toggle("loading-visible", show); -} - -function showError(container, message) { - container.innerHTML = `

${escapeHtml(message)}

`; -} - -function severityBadge(severity) { - const cls = `badge badge--${severity}`; - return `${severity.toUpperCase()}`; -} - -function escapeHtml(text) { - const div = document.createElement("div"); - div.textContent = text; - return div.innerHTML; -} - -function sleep(ms) { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - -// --------------------------------------------------------------------------- -// Initialize -// --------------------------------------------------------------------------- - -document.addEventListener("DOMContentLoaded", () => { - checkHealth(); -}); diff --git a/playground/web/index.html b/playground/web/index.html deleted file mode 100644 index e65f00ebc4461587ecded10aae5c16fb15b2d5da..0000000000000000000000000000000000000000 --- a/playground/web/index.html +++ /dev/null @@ -1,136 +0,0 @@ - - - - - - - DataForge Playground - - - - - - - - -
-
-

DataForge Playground

-

Upload a CSV. Profile it for quality issues. Preview repairs. No install, no signup.

-
- -
- -
- -
-
-
- -
-
- -
-
-
-
- -
-
- - -
-
-
- - -
- - -
-
- -

Analyzing your data...

-
-
-
- -
-
- -

Proposing repairs...

-
-
-
- -
-
-
-
Playground is stateless
-

Transaction journals shown here are ephemeral and will not be persisted. - Install the CLI to use dataforge repair --apply and - dataforge revert <txn_id>.

-
-

-                
-
-
- - -
- - Your file is processed in memory and discarded after the response. - No cookies, no analytics of file contents. - -
- - -
- Run locally instead -
pip install dataforge
-dataforge profile your_data.csv
-dataforge repair your_data.csv --dry-run
-

See the repository for full documentation.

-
-
- - - - - - - - diff --git a/playground/web/style.css b/playground/web/style.css deleted file mode 100644 index a56bb667d538ef1b64cd3574557826aa6c737586..0000000000000000000000000000000000000000 --- a/playground/web/style.css +++ /dev/null @@ -1,307 +0,0 @@ -/* - * DataForge Playground — custom style overrides on Pico.css. - * - * Minimal additions for: severity badges, diff colors, progress bar - * polish, cold-start banner, and dark-mode compatibility. - */ - -/* --------------------------------------------------------------------------- - Layout & typography - --------------------------------------------------------------------------- */ - -:root { - --df-safe: #22c55e; - --df-review: #f59e0b; - --df-unsafe: #ef4444; - --df-diff-add-bg: rgba(34, 197, 94, 0.12); - --df-diff-del-bg: rgba(239, 68, 68, 0.12); - --df-diff-add-text: #16a34a; - --df-diff-del-text: #dc2626; - --df-banner-bg: #1e293b; - --df-banner-text: #e2e8f0; -} - -@media (prefers-color-scheme: dark) { - :root { - --df-diff-add-bg: rgba(34, 197, 94, 0.18); - --df-diff-del-bg: rgba(239, 68, 68, 0.18); - --df-diff-add-text: #4ade80; - --df-diff-del-text: #f87171; - --df-banner-bg: #0f172a; - --df-banner-text: #cbd5e1; - } -} - -body { - min-height: 100vh; - display: flex; - flex-direction: column; -} - -main { - flex: 1; -} - -footer { - text-align: center; - padding: 1rem 0; - opacity: 0.7; -} - -/* --------------------------------------------------------------------------- - Cold-start banner - --------------------------------------------------------------------------- */ - -.banner { - background: var(--df-banner-bg); - color: var(--df-banner-text); - padding: 0.75rem 1rem; - border-radius: 0.5rem; - text-align: center; - transition: opacity 0.3s ease, max-height 0.3s ease; - margin-bottom: 1rem; -} - -.banner--hidden { - opacity: 0; - max-height: 0; - overflow: hidden; - padding: 0; - margin: 0; - pointer-events: none; -} - -.banner--visible { - opacity: 1; - max-height: 6rem; -} - -.banner progress { - width: 100%; - max-width: 20rem; - margin-top: 0.5rem; -} - -/* --------------------------------------------------------------------------- - Severity badges - --------------------------------------------------------------------------- */ - -.badge { - display: inline-block; - padding: 0.15em 0.55em; - border-radius: 0.35em; - font-size: 0.8em; - font-weight: 600; - letter-spacing: 0.03em; - text-transform: uppercase; -} - -.badge--safe { - background: var(--df-safe); - color: #fff; -} - -.badge--review { - background: var(--df-review); - color: #fff; -} - -.badge--unsafe { - background: var(--df-unsafe); - color: #fff; -} - -/* --------------------------------------------------------------------------- - Diff view - --------------------------------------------------------------------------- */ - -.diff-view { - font-family: "Fira Code", "Cascadia Code", "JetBrains Mono", monospace; - font-size: 0.875rem; - line-height: 1.6; - padding: 1rem; - border-radius: 0.5rem; - overflow-x: auto; - white-space: pre-wrap; - word-break: break-word; -} - -.diff-header { - color: var(--pico-color); - font-weight: 600; - display: block; -} - -.addition { - background: var(--df-diff-add-bg); - color: var(--df-diff-add-text); - display: block; -} - -.deletion { - background: var(--df-diff-del-bg); - color: var(--df-diff-del-text); - display: block; -} - -.diff-reason { - opacity: 0.65; - font-style: italic; - display: block; -} - -/* --------------------------------------------------------------------------- - Results section - --------------------------------------------------------------------------- */ - -.results-hidden { - display: none; -} - -.meta-summary { - padding: 0.5rem 0; - margin-bottom: 0.5rem; - font-size: 0.95rem; -} - -.no-issues { - padding: 1rem; - text-align: center; - opacity: 0.7; -} - -.table-wrapper { - overflow-x: auto; -} - -.row-indices { - font-family: "Fira Code", monospace; - font-size: 0.85em; -} - -/* --------------------------------------------------------------------------- - Tabs - --------------------------------------------------------------------------- */ - -[role="tablist"] { - display: flex; - gap: 0; - list-style: none; - padding: 0; - margin: 0 0 1rem 0; - border-bottom: 2px solid var(--pico-muted-border-color); -} - -[role="tablist"] li { - margin: 0; - padding: 0; -} - -[role="tab"] { - background: none; - border: none; - border-bottom: 2px solid transparent; - padding: 0.5rem 1rem; - cursor: pointer; - font-weight: 500; - color: var(--pico-muted-color); - transition: color 0.2s, border-color 0.2s; - margin-bottom: -2px; -} - -[role="tab"]:hover { - color: var(--pico-color); -} - -[role="tab"].tab-active { - color: var(--pico-primary); - border-bottom-color: var(--pico-primary); -} - -.tab-panel--hidden { - display: none; -} - -/* --------------------------------------------------------------------------- - Loading states - --------------------------------------------------------------------------- */ - -.loading-hidden { - display: none; -} - -.loading-visible { - display: block; - text-align: center; - padding: 1rem; -} - -.loading-visible progress { - width: 100%; - max-width: 20rem; - margin: 0 auto; -} - -/* --------------------------------------------------------------------------- - Error card - --------------------------------------------------------------------------- */ - -.error-card { - border-left: 4px solid var(--df-unsafe); - padding: 0.75rem 1rem; -} - -/* --------------------------------------------------------------------------- - Revert notice - --------------------------------------------------------------------------- */ - -.revert-notice { - border-left: 4px solid var(--df-review); -} - -.journal-empty { - color: var(--pico-muted-color); - font-style: italic; -} - -/* --------------------------------------------------------------------------- - Privacy notice - --------------------------------------------------------------------------- */ - -.privacy-notice { - display: block; - text-align: center; - padding: 0.75rem 0; - opacity: 0.6; -} - -/* --------------------------------------------------------------------------- - Toggle label - --------------------------------------------------------------------------- */ - -.toggle-label { - display: flex; - align-items: center; - gap: 0.5rem; - cursor: pointer; -} - -.toggle-label input[disabled] { - cursor: not-allowed; -} - -/* --------------------------------------------------------------------------- - Upload section - --------------------------------------------------------------------------- */ - -#upload-section { - margin-bottom: 1.5rem; -} - -/* --------------------------------------------------------------------------- - Run locally details - --------------------------------------------------------------------------- */ - -#run-locally { - margin-top: 2rem; -} diff --git a/pyproject.toml b/pyproject.toml index 9f896eaf3d322822556da48c4eeddecba30295c3..6458adcdeeb9ab5293c7bfe080e3b1a1f89df756 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,36 +1,40 @@ [project] -name = "dataforge" +name = "dataforge15" version = "0.1.0" -description = "Scaffold for a CLI-first data-quality repair agent." +description = "DataForge15: CLI-first data-quality detection and reversible repair for tabular data." readme = "README.md" -license = { text = "Apache-2.0" } +license = "Apache-2.0" requires-python = ">=3.11,<3.13" keywords = ["data-quality", "ai-agent", "llm", "rl", "smt", "dbt"] classifiers = [ "Development Status :: 3 - Alpha", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] dependencies = [ - "pandas>=2.2", - "pyarrow>=16.0", "pydantic>=2.7", - "typer>=0.12", + "typer>=0.24,<0.25", "rich>=13.7", "z3-solver>=4.13", - "networkx>=3.3", - "causal-learn>=0.1.4", - "scipy>=1.13", - "httpx>=0.27", - "tenacity>=8.3", - "python-dotenv>=1.0", "pyyaml>=6.0", - "sqlglot>=25.0", - "duckdb>=1.0", ] [project.optional-dependencies] +bench = [ + "pandas>=2.2", + "httpx>=0.27", + "tenacity>=8.3", + "python-dotenv>=1.0", + "pyarrow>=16.0", +] +causal = [ + "pandas>=2.2", + "numpy>=1.26", + "networkx>=3.3", + "causal-learn>=0.1.4", + "hyppo>=0.5.2", + "scipy>=1.13", +] dev = [ "pytest>=8.2", "pytest-cov>=5.0", @@ -38,33 +42,76 @@ dev = [ "pytest-xdist>=3.6", "hypothesis>=6.100", "mutmut>=3.5", + "build>=1.2", + "pip-audit>=2.10,<3", + "cyclonedx-bom>=7.3,<8", + "idna>=3.15", + "pip>=26.1.1", + "urllib3>=2.7", "ruff>=0.11", "mypy>=1.10", "pandas-stubs>=2.2", "types-PyYAML", + "huggingface_hub==1.13.0", + "httpx>=0.27", + "tenacity>=8.3", + "python-dotenv>=1.0", + "pyarrow>=16.0", + "networkx>=3.3", + "causal-learn>=0.1.4", + "hyppo>=0.5.2", + "scipy>=1.13", + "sqlglot>=25.0", + "duckdb>=1.0", ] train = [ - "transformers>=4.44", - "datasets>=2.20", - "trl>=1.0", - "accelerate>=0.33", - "peft>=0.12", - "bitsandbytes>=0.43", + "trl==1.4.0", + "transformers==5.7.0", + "accelerate==1.13.0", + "peft==0.19.1", + "bitsandbytes==0.49.2", + "datasets==4.8.5", + "huggingface_hub==1.13.0", + "pyyaml==6.0.3", + "pandas==2.3.3", + "tensorboard==2.20.0", ] eval = [ "matplotlib>=3.9", "seaborn>=0.13", ] +providers = [ + "httpx>=0.27", + "tenacity>=8.3", + "python-dotenv>=1.0", +] +pandas = [ + "pandas>=2.2", +] playground = [ - "fastapi>=0.111", - "uvicorn[standard]>=0.30", - "python-multipart>=0.0.9", + "pandas>=2.2", + "fastapi>=0.136.1", + "starlette>=0.49.1,<2", + "uvicorn[standard]>=0.35", + "python-multipart>=0.0.27", + "slowapi>=0.1.9", +] +openenv = [ + "pandas>=2.2", + "openenv-core[core]>=0.2.2", + "duckdb>=1.0", + "sqlglot>=25.0", + "scipy>=1.13", + "networkx>=3.3", + "causal-learn>=0.1.4", + "hyppo>=0.5.2", ] all = [ - "dataforge[dev,train,eval,playground]", + "dataforge15[bench,causal,dev,eval,pandas,playground,providers,train,openenv]", ] [project.scripts] +dataforge15 = "dataforge.cli:app" dataforge = "dataforge.cli:app" [build-system] @@ -73,16 +120,31 @@ build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] where = ["."] -include = ["dataforge*"] +include = ["dataforge", "dataforge.*", "data_quality_env", "data_quality_env.*"] + +[tool.setuptools.package-data] +dataforge = [ + "py.typed", + "fixtures/*.csv", + "fixtures/*.yaml", + "datasets/embedded/**/*.csv", + "safety/constitutions/*.yaml", + "safety/adversarial/*.yaml", +] [tool.ruff] line-length = 100 target-version = "py311" +extend-exclude = [".hf-space-repo", ".hf-space-stage", ".hf-space-stage-plan"] [tool.ruff.lint] select = ["E", "F", "W", "I", "N", "UP", "B", "A", "C4", "PIE", "RET", "SIM"] ignore = ["E501"] +[tool.ruff.lint.per-file-ignores] +"data_quality_env/**/*.py" = ["B007", "B027", "E402", "E731", "F401", "F541", "F841", "I001", "N", "RET", "SIM", "UP"] +"training/kaggle/sft_warmup_kaggle.ipynb" = ["E402"] + [tool.ruff.format] quote-style = "double" indent-style = "space" @@ -97,10 +159,17 @@ warn_unused_ignores = true disallow_untyped_defs = true explicit_package_bases = true exclude = [ + "^\\.hf-space-repo/", + "^\\.hf-space-stage/", + "^\\.hf-space-stage-plan/", "^[^/]*\\.py$", # loose root-level scripts (hackathon legacy) - "^(server|training|playground|benchmark_results|datasets)/", + "^(training|playground|benchmark_results|datasets)/", ] +[[tool.mypy.overrides]] +module = ["data_quality_env.*"] +ignore_errors = true + [tool.pytest.ini_options] minversion = "8.0" addopts = "-ra --strict-markers --strict-config"