Deploy DataForge playground API
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +63 -28
- README.md +21 -18
- README_MAIN.md +0 -96
- dataforge/__init__.py +115 -2
- dataforge/agent/__init__.py +16 -1
- dataforge/agent/providers.py +11 -3
- dataforge/agent/scratchpad.py +183 -0
- dataforge/agent/tool_actions.py +343 -0
- dataforge/bench/core.py +6 -0
- dataforge/bench/groq_client.py +306 -27
- dataforge/bench/methods.py +35 -0
- dataforge/bench/report.py +19 -13
- dataforge/bench/runner.py +45 -6
- dataforge/causal/__init__.py +21 -1
- dataforge/causal/dag.py +174 -0
- dataforge/causal/pc.py +232 -0
- dataforge/causal/root_cause.py +193 -0
- dataforge/cli/__init__.py +10 -4
- dataforge/cli/audit.py +70 -0
- dataforge/cli/bench.py +23 -4
- dataforge/cli/common.py +26 -4
- dataforge/cli/profile.py +61 -16
- dataforge/cli/release.py +39 -0
- dataforge/cli/repair.py +104 -249
- dataforge/cli/watch.py +142 -0
- dataforge/datasets/embedded/hospital/clean.csv +11 -0
- dataforge/datasets/embedded/hospital/dirty.csv +11 -0
- dataforge/datasets/real_world.py +37 -7
- dataforge/detectors/__init__.py +2 -4
- dataforge/detectors/base.py +5 -5
- dataforge/detectors/decimal_shift.py +11 -17
- dataforge/detectors/fd_violation.py +21 -24
- dataforge/detectors/type_mismatch.py +6 -13
- dataforge/engine/__init__.py +33 -1
- dataforge/engine/repair.py +670 -0
- dataforge/env/__init__.py +22 -1
- dataforge/env/environment.py +884 -0
- dataforge/env/observation.py +61 -0
- dataforge/env/openenv_core.py +146 -0
- dataforge/env/reward.py +128 -0
- dataforge/env/server.py +175 -0
- dataforge/evaluation_contract.py +76 -0
- dataforge/fixtures/hospital_10rows.csv +11 -0
- dataforge/fixtures/hospital_schema.yaml +17 -0
- dataforge/http/__init__.py +1 -0
- dataforge/http/problem.py +99 -0
- dataforge/integrations/dbt.py +1 -0
- dataforge/observability.py +76 -0
- dataforge/py.typed +1 -0
- dataforge/release/__init__.py +2 -0
Dockerfile
CHANGED
|
@@ -1,28 +1,63 @@
|
|
| 1 |
-
# DataForge Playground
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
RUN
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
COPY playground/api/
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DataForge Playground — Multi-stage Docker build for HF Spaces.
|
| 2 |
+
#
|
| 3 |
+
# Target: <= 600 MB image. Runs as non-root UID 1000 (HF requirement).
|
| 4 |
+
# Single-worker uvicorn with --timeout-keep-alive 5 (slowloris mitigation).
|
| 5 |
+
#
|
| 6 |
+
# See specs/SPEC_playground.md §4 and §6.5.
|
| 7 |
+
|
| 8 |
+
# ============================================================
|
| 9 |
+
# Stage 1: builder — install all Python dependencies
|
| 10 |
+
# ============================================================
|
| 11 |
+
FROM python:3.12-slim AS builder
|
| 12 |
+
|
| 13 |
+
WORKDIR /build
|
| 14 |
+
|
| 15 |
+
# System deps for building wheels
|
| 16 |
+
RUN apt-get update && \
|
| 17 |
+
apt-get install -y --no-install-recommends gcc g++ && \
|
| 18 |
+
rm -rf /var/lib/apt/lists/*
|
| 19 |
+
|
| 20 |
+
# Install playground API requirements
|
| 21 |
+
COPY playground/api/requirements.txt /build/requirements.txt
|
| 22 |
+
RUN pip install --no-cache-dir -r /build/requirements.txt
|
| 23 |
+
|
| 24 |
+
# Copy dataforge source and install it
|
| 25 |
+
COPY pyproject.toml /build/dataforge_src/pyproject.toml
|
| 26 |
+
COPY README.md /build/dataforge_src/README.md
|
| 27 |
+
COPY dataforge/ /build/dataforge_src/dataforge/
|
| 28 |
+
COPY constitutions/ /build/dataforge_src/constitutions/
|
| 29 |
+
RUN pip install --no-cache-dir /build/dataforge_src
|
| 30 |
+
|
| 31 |
+
# ============================================================
|
| 32 |
+
# Stage 2: runtime — minimal image with only installed packages
|
| 33 |
+
# ============================================================
|
| 34 |
+
FROM python:3.12-slim
|
| 35 |
+
|
| 36 |
+
# HF Spaces requires non-root user with UID 1000
|
| 37 |
+
RUN useradd -m -u 1000 user
|
| 38 |
+
|
| 39 |
+
# Copy installed Python packages from builder
|
| 40 |
+
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
| 41 |
+
COPY --from=builder /usr/local/bin /usr/local/bin
|
| 42 |
+
|
| 43 |
+
# Copy constitutions to the site-packages-relative path SafetyFilter expects.
|
| 44 |
+
COPY --from=builder /build/dataforge_src/constitutions /usr/local/lib/python3.12/site-packages/constitutions
|
| 45 |
+
|
| 46 |
+
# Copy application code
|
| 47 |
+
COPY playground/api/app.py /home/user/app/app.py
|
| 48 |
+
COPY playground/api/samples/ /home/user/app/samples/
|
| 49 |
+
|
| 50 |
+
# Switch to non-root user
|
| 51 |
+
USER user
|
| 52 |
+
WORKDIR /home/user/app
|
| 53 |
+
|
| 54 |
+
# Expose the port HF Spaces expects
|
| 55 |
+
EXPOSE 7860
|
| 56 |
+
|
| 57 |
+
# Environment
|
| 58 |
+
ENV PORT=7860
|
| 59 |
+
ENV DATAFORGE_PLAYGROUND_DEV=0
|
| 60 |
+
|
| 61 |
+
# Start uvicorn with single worker (slowapi in-memory limiter contract)
|
| 62 |
+
# and honor PORT for Hugging Face runtime assignment.
|
| 63 |
+
CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1 --timeout-keep-alive 5"]
|
README.md
CHANGED
|
@@ -7,37 +7,40 @@ sdk: docker
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: apache-2.0
|
| 10 |
-
short_description:
|
| 11 |
---
|
| 12 |
|
| 13 |
-
# DataForge Playground
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
|
| 20 |
-
-
|
| 21 |
-
|
| 22 |
-
-
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
|
| 27 |
-
-
|
| 28 |
-
-
|
| 29 |
-
-
|
|
|
|
| 30 |
|
| 31 |
-
## Run
|
| 32 |
|
| 33 |
```bash
|
| 34 |
-
pip install
|
| 35 |
-
|
| 36 |
-
|
| 37 |
```
|
| 38 |
|
| 39 |
## Source
|
| 40 |
|
| 41 |
-
- Main repository:
|
| 42 |
- Spec: `specs/SPEC_playground.md`
|
| 43 |
- License: Apache-2.0
|
|
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: apache-2.0
|
| 10 |
+
short_description: Profile CSVs and dry-run safe repairs.
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# DataForge Playground API
|
| 14 |
|
| 15 |
+
This is the API backend for the DataForge playground. The browser UI is deployed
|
| 16 |
+
separately through Cloudflare Workers Static Assets; this Hugging Face Docker
|
| 17 |
+
Space serves stateless CSV profiling and dry-run repair endpoints.
|
| 18 |
|
| 19 |
+
## What It Does
|
| 20 |
|
| 21 |
+
- Profile: detects type mismatches, decimal shifts, and functional dependency
|
| 22 |
+
violations.
|
| 23 |
+
- Repair dry run: proposes fixes through SafetyFilter -> SMTVerifier and
|
| 24 |
+
returns an ephemeral transaction receipt without persisting user data.
|
| 25 |
+
- Samples: serves small deterministic CSV examples for the static frontend.
|
| 26 |
|
| 27 |
+
## What It Does Not Do
|
| 28 |
|
| 29 |
+
- It does not persist uploaded files.
|
| 30 |
+
- It does not use cookies or analytics for file contents.
|
| 31 |
+
- It does not call an LLM by default.
|
| 32 |
+
- It does not perform autonomous production repair.
|
| 33 |
|
| 34 |
+
## Run Locally
|
| 35 |
|
| 36 |
```bash
|
| 37 |
+
python -m pip install -e ".[dev]"
|
| 38 |
+
pip install -r playground/api/requirements.txt
|
| 39 |
+
uvicorn playground.api.app:app --reload --port 7860
|
| 40 |
```
|
| 41 |
|
| 42 |
## Source
|
| 43 |
|
| 44 |
+
- Main repository: `github.com/Praneshrajan15/data-quality-env`
|
| 45 |
- Spec: `specs/SPEC_playground.md`
|
| 46 |
- License: Apache-2.0
|
README_MAIN.md
DELETED
|
@@ -1,96 +0,0 @@
|
|
| 1 |
-
# DataForge
|
| 2 |
-
|
| 3 |
-
DataForge currently ships a real Week 3 CLI for CSV profiling and repair.
|
| 4 |
-
|
| 5 |
-
This repository now includes shipped detectors, deterministic repairers,
|
| 6 |
-
constitutional safety gating, SMT-backed structural verification, reversible
|
| 7 |
-
transaction logs, and real-world benchmark infrastructure. The hosted
|
| 8 |
-
playground, warehouse integrations, and trained model family remain future
|
| 9 |
-
work.
|
| 10 |
-
|
| 11 |
-
## Current Status
|
| 12 |
-
|
| 13 |
-
- `dataforge profile`, `dataforge repair`, `dataforge revert`, and `dataforge bench`
|
| 14 |
-
- Three shipped detectors: `type_mismatch`, `decimal_shift`, `fd_violation`
|
| 15 |
-
- Three shipped repairers with safety + verifier gating in the apply path
|
| 16 |
-
- Reversible transaction logs with byte-identical revert via source snapshots
|
| 17 |
-
- Benchmark/report generation infrastructure for Hospital / Flights / Beers
|
| 18 |
-
- `Makefile` targets for setup, lint, type-checking, and tests
|
| 19 |
-
- CI plus unit / integration / property / adversarial coverage
|
| 20 |
-
|
| 21 |
-
## Benchmark Results
|
| 22 |
-
|
| 23 |
-
<!-- BENCH:START -->
|
| 24 |
-
Generated from `eval/results/agent_comparison.json`.
|
| 25 |
-
|
| 26 |
-
| Method | Precision | Recall | F1 | Avg Steps | Quota Units |
|
| 27 |
-
| --- | --- | --- | --- | --- | --- |
|
| 28 |
-
| heuristic | 0.0000 | 0.0000 | 0.0000 | 134.33 | 0.0000 |
|
| 29 |
-
| llm_react | Skipped | Skipped | Skipped | Skipped | Skipped |
|
| 30 |
-
| llm_zeroshot | Skipped | Skipped | Skipped | Skipped | Skipped |
|
| 31 |
-
| random | 0.0038 | 0.0003 | 0.0005 | 150.33 | 0.0000 |
|
| 32 |
-
|
| 33 |
-
See `BENCHMARK_REPORT.md` for per-dataset tables, error bars, and citation-only SOTA rows.
|
| 34 |
-
|
| 35 |
-
Skipped methods in this run: DATAFORGE_LLM_PROVIDER must be set to groq.
|
| 36 |
-
<!-- BENCH:END -->
|
| 37 |
-
|
| 38 |
-
## Local Setup
|
| 39 |
-
|
| 40 |
-
```bash
|
| 41 |
-
make setup
|
| 42 |
-
make lint
|
| 43 |
-
make type
|
| 44 |
-
make test
|
| 45 |
-
```
|
| 46 |
-
|
| 47 |
-
Verification works on Linux, macOS, or Windows (with Git Bash as the
|
| 48 |
-
shell substrate for GNU Make). Requires Python 3.11 or 3.12
|
| 49 |
-
(`requires-python = ">=3.11,<3.13"`).
|
| 50 |
-
|
| 51 |
-
### Windows-specific setup
|
| 52 |
-
|
| 53 |
-
```powershell
|
| 54 |
-
# Install Python 3.12 and GNU Make if not present
|
| 55 |
-
winget install -e --id Python.Python.3.12
|
| 56 |
-
winget install -e --id ezwinports.make
|
| 57 |
-
|
| 58 |
-
# Create and activate a project venv
|
| 59 |
-
py -3.12 -m venv .venv
|
| 60 |
-
.\.venv\Scripts\Activate.ps1
|
| 61 |
-
|
| 62 |
-
# Install dependencies and verify
|
| 63 |
-
python -m pip install -e ".[all]"
|
| 64 |
-
make lint && make type && make test
|
| 65 |
-
```
|
| 66 |
-
|
| 67 |
-
Git for Windows provides the Bash implementation the Makefile uses on Windows.
|
| 68 |
-
Do not rely on `C:\Windows\System32\bash.exe` (WSL).
|
| 69 |
-
|
| 70 |
-
## Environment Variables
|
| 71 |
-
|
| 72 |
-
Future provider keys belong in a root `.env` file that is gitignored and meant
|
| 73 |
-
to be loaded with `python-dotenv`.
|
| 74 |
-
|
| 75 |
-
- `GROQ_API_KEY`
|
| 76 |
-
- `GEMINI_API_KEY`
|
| 77 |
-
- `CEREBRAS_API_KEY`
|
| 78 |
-
- `OPENROUTER_API_KEY`
|
| 79 |
-
- `HF_TOKEN`
|
| 80 |
-
|
| 81 |
-
## Repository Docs
|
| 82 |
-
|
| 83 |
-
- [.cursor/rules/dataforge.md](.cursor/rules/dataforge.md) — always-applied rules
|
| 84 |
-
- [ARCHITECTURE.md](ARCHITECTURE.md) — system diagram and dependency justification
|
| 85 |
-
- [DECISIONS.md](DECISIONS.md) — technical decision log
|
| 86 |
-
- [CONTRIBUTING.md](CONTRIBUTING.md) — workflow and code standards
|
| 87 |
-
- [CLAUDE.md](CLAUDE.md) — living knowledge base for Cursor sessions
|
| 88 |
-
- [CURSOR_MASTER.md](CURSOR_MASTER.md) — full context and prompt pack
|
| 89 |
-
- [META_CONTEXT.md](META_CONTEXT.md) — meta-context (read before writing code)
|
| 90 |
-
- [FILE_STRUCTURE.md](FILE_STRUCTURE.md) — canonical target directory tree
|
| 91 |
-
- [SECURITY.md](SECURITY.md) — vulnerability reporting policy
|
| 92 |
-
- [specs/SPEC_TEMPLATE.md](specs/SPEC_TEMPLATE.md) — spec template for new modules
|
| 93 |
-
|
| 94 |
-
## License
|
| 95 |
-
|
| 96 |
-
Apache-2.0. See [LICENSE](LICENSE).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataforge/__init__.py
CHANGED
|
@@ -1,5 +1,118 @@
|
|
| 1 |
-
"""DataForge public package.
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
__version__ = "0.1.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataForge public package.
|
| 2 |
|
| 3 |
+
The root package is the stable facade for integration surfaces. Symbols are
|
| 4 |
+
resolved lazily so importing :mod:`dataforge` does not eagerly import pandas,
|
| 5 |
+
FastAPI-facing helpers, or the SMT stack.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from importlib import import_module
|
| 11 |
+
from typing import TYPE_CHECKING, Any
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from dataforge.cli.common import load_schema, read_csv, schema_from_mapping
|
| 15 |
+
from dataforge.detectors import Issue, Schema, Severity, run_all_detectors
|
| 16 |
+
from dataforge.engine.repair import (
|
| 17 |
+
CandidateFix,
|
| 18 |
+
RepairFailure,
|
| 19 |
+
RepairPipelineRequest,
|
| 20 |
+
RepairPipelineResult,
|
| 21 |
+
RepairReceipt,
|
| 22 |
+
VerifiedFix,
|
| 23 |
+
run_repair_pipeline,
|
| 24 |
+
)
|
| 25 |
+
from dataforge.repair_contract import CONTRACT_VERSION
|
| 26 |
+
from dataforge.repairers import ProposedFix
|
| 27 |
+
from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict
|
| 28 |
+
from dataforge.transactions.log import (
|
| 29 |
+
TransactionAuditReport,
|
| 30 |
+
TransactionAuditVerdict,
|
| 31 |
+
TransactionLogError,
|
| 32 |
+
verify_transaction_log,
|
| 33 |
+
)
|
| 34 |
+
from dataforge.transactions.revert import TransactionRevertError, revert_transaction
|
| 35 |
+
from dataforge.transactions.txn import CellFix, RepairTransaction
|
| 36 |
+
from dataforge.verifier import SMTVerifier, VerificationResult, VerificationVerdict
|
| 37 |
+
|
| 38 |
+
__all__ = [
|
| 39 |
+
"CONTRACT_VERSION",
|
| 40 |
+
"CandidateFix",
|
| 41 |
+
"CellFix",
|
| 42 |
+
"Issue",
|
| 43 |
+
"ProposedFix",
|
| 44 |
+
"RepairFailure",
|
| 45 |
+
"RepairPipelineRequest",
|
| 46 |
+
"RepairPipelineResult",
|
| 47 |
+
"RepairReceipt",
|
| 48 |
+
"RepairTransaction",
|
| 49 |
+
"SMTVerifier",
|
| 50 |
+
"SafetyContext",
|
| 51 |
+
"SafetyFilter",
|
| 52 |
+
"SafetyResult",
|
| 53 |
+
"SafetyVerdict",
|
| 54 |
+
"Schema",
|
| 55 |
+
"Severity",
|
| 56 |
+
"TransactionAuditReport",
|
| 57 |
+
"TransactionAuditVerdict",
|
| 58 |
+
"TransactionLogError",
|
| 59 |
+
"TransactionRevertError",
|
| 60 |
+
"VerificationResult",
|
| 61 |
+
"VerificationVerdict",
|
| 62 |
+
"VerifiedFix",
|
| 63 |
+
"__version__",
|
| 64 |
+
"load_schema",
|
| 65 |
+
"read_csv",
|
| 66 |
+
"revert_transaction",
|
| 67 |
+
"run_all_detectors",
|
| 68 |
+
"run_repair_pipeline",
|
| 69 |
+
"schema_from_mapping",
|
| 70 |
+
"verify_transaction_log",
|
| 71 |
+
]
|
| 72 |
|
| 73 |
__version__ = "0.1.0"
|
| 74 |
+
|
| 75 |
+
_PUBLIC_EXPORTS: dict[str, tuple[str, str]] = {
|
| 76 |
+
"CONTRACT_VERSION": ("dataforge.repair_contract", "CONTRACT_VERSION"),
|
| 77 |
+
"CandidateFix": ("dataforge.engine.repair", "CandidateFix"),
|
| 78 |
+
"CellFix": ("dataforge.transactions.txn", "CellFix"),
|
| 79 |
+
"Issue": ("dataforge.detectors", "Issue"),
|
| 80 |
+
"ProposedFix": ("dataforge.repairers", "ProposedFix"),
|
| 81 |
+
"RepairFailure": ("dataforge.engine.repair", "RepairFailure"),
|
| 82 |
+
"RepairPipelineRequest": ("dataforge.engine.repair", "RepairPipelineRequest"),
|
| 83 |
+
"RepairPipelineResult": ("dataforge.engine.repair", "RepairPipelineResult"),
|
| 84 |
+
"RepairReceipt": ("dataforge.engine.repair", "RepairReceipt"),
|
| 85 |
+
"RepairTransaction": ("dataforge.transactions.txn", "RepairTransaction"),
|
| 86 |
+
"SMTVerifier": ("dataforge.verifier", "SMTVerifier"),
|
| 87 |
+
"SafetyContext": ("dataforge.safety", "SafetyContext"),
|
| 88 |
+
"SafetyFilter": ("dataforge.safety", "SafetyFilter"),
|
| 89 |
+
"SafetyResult": ("dataforge.safety", "SafetyResult"),
|
| 90 |
+
"SafetyVerdict": ("dataforge.safety", "SafetyVerdict"),
|
| 91 |
+
"Schema": ("dataforge.detectors", "Schema"),
|
| 92 |
+
"Severity": ("dataforge.detectors", "Severity"),
|
| 93 |
+
"TransactionAuditReport": ("dataforge.transactions.log", "TransactionAuditReport"),
|
| 94 |
+
"TransactionAuditVerdict": ("dataforge.transactions.log", "TransactionAuditVerdict"),
|
| 95 |
+
"TransactionLogError": ("dataforge.transactions.log", "TransactionLogError"),
|
| 96 |
+
"TransactionRevertError": ("dataforge.transactions.revert", "TransactionRevertError"),
|
| 97 |
+
"VerificationResult": ("dataforge.verifier", "VerificationResult"),
|
| 98 |
+
"VerificationVerdict": ("dataforge.verifier", "VerificationVerdict"),
|
| 99 |
+
"VerifiedFix": ("dataforge.engine.repair", "VerifiedFix"),
|
| 100 |
+
"load_schema": ("dataforge.cli.common", "load_schema"),
|
| 101 |
+
"read_csv": ("dataforge.cli.common", "read_csv"),
|
| 102 |
+
"revert_transaction": ("dataforge.transactions.revert", "revert_transaction"),
|
| 103 |
+
"run_all_detectors": ("dataforge.detectors", "run_all_detectors"),
|
| 104 |
+
"run_repair_pipeline": ("dataforge.engine.repair", "run_repair_pipeline"),
|
| 105 |
+
"schema_from_mapping": ("dataforge.cli.common", "schema_from_mapping"),
|
| 106 |
+
"verify_transaction_log": ("dataforge.transactions.log", "verify_transaction_log"),
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def __getattr__(name: str) -> Any:
|
| 111 |
+
"""Resolve public facade exports on first use."""
|
| 112 |
+
try:
|
| 113 |
+
module_name, attribute_name = _PUBLIC_EXPORTS[name]
|
| 114 |
+
except KeyError as exc:
|
| 115 |
+
raise AttributeError(name) from exc
|
| 116 |
+
value = getattr(import_module(module_name), attribute_name)
|
| 117 |
+
globals()[name] = value
|
| 118 |
+
return value
|
dataforge/agent/__init__.py
CHANGED
|
@@ -1 +1,16 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataForge agent package — typed tool-use actions and scratchpad.
|
| 2 |
+
|
| 3 |
+
Public API:
|
| 4 |
+
parse_action — Parse raw dict into typed Action model.
|
| 5 |
+
Action — Discriminated union of all action types.
|
| 6 |
+
Scratchpad — In-episode hypothesis tracker.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataforge.agent.scratchpad import Scratchpad
|
| 10 |
+
from dataforge.agent.tool_actions import Action, parse_action
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"Action",
|
| 14 |
+
"Scratchpad",
|
| 15 |
+
"parse_action",
|
| 16 |
+
]
|
dataforge/agent/providers.py
CHANGED
|
@@ -59,8 +59,9 @@ def get_provider_name() -> str:
|
|
| 59 |
"""Read the active provider from the environment.
|
| 60 |
|
| 61 |
Returns:
|
| 62 |
-
The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
Example:
|
| 66 |
>>> import os
|
|
@@ -68,7 +69,14 @@ def get_provider_name() -> str:
|
|
| 68 |
>>> get_provider_name()
|
| 69 |
'gemini'
|
| 70 |
"""
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
async def complete(
|
|
|
|
| 59 |
"""Read the active provider from the environment.
|
| 60 |
|
| 61 |
Returns:
|
| 62 |
+
The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``.
|
| 63 |
+
When no explicit provider is configured, prefer a provider whose
|
| 64 |
+
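credential is present in the environment (``GROQ_API_KEY`` is checked before ``GEMINI_API_KEY``); if neither is set, fall back to ``"groq"``.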
|
| 65 |
|
| 66 |
Example:
|
| 67 |
>>> import os
|
|
|
|
| 69 |
>>> get_provider_name()
|
| 70 |
'gemini'
|
| 71 |
"""
|
| 72 |
+
configured = os.environ.get("DATAFORGE_LLM_PROVIDER")
|
| 73 |
+
if configured:
|
| 74 |
+
return configured.lower()
|
| 75 |
+
if os.environ.get("GROQ_API_KEY"):
|
| 76 |
+
return "groq"
|
| 77 |
+
if os.environ.get("GEMINI_API_KEY"):
|
| 78 |
+
return "gemini"
|
| 79 |
+
return "groq"
|
| 80 |
|
| 81 |
|
| 82 |
async def complete(
|
dataforge/agent/scratchpad.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""In-episode hypothesis and issue tracker for the DataForge RL agent.
|
| 2 |
+
|
| 3 |
+
The scratchpad is a mutable, episode-scoped data structure that the agent
|
| 4 |
+
uses to record hypotheses, confirmed issues, and dead ends. The environment
|
| 5 |
+
exposes a compact summary of the scratchpad in each observation, enabling
|
| 6 |
+
the agent to reason about its investigation history without direct access
|
| 7 |
+
to the underlying data structure.
|
| 8 |
+
|
| 9 |
+
Example::
|
| 10 |
+
|
| 11 |
+
>>> from dataforge.agent.scratchpad import Scratchpad
|
| 12 |
+
>>> pad = Scratchpad()
|
| 13 |
+
>>> pad.add_hypothesis("Rating column has decimal shift", [5], ["rating"], "decimal_shift")
|
| 14 |
+
>>> pad.confirm_issue(5, "rating", "decimal_shift")
|
| 15 |
+
>>> pad.summary()
|
| 16 |
+
'Hypotheses: 1 (1 pending). Confirmed: 1. Dead ends: 0.'
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"ConfirmedIssue",
|
| 25 |
+
"DeadEnd",
|
| 26 |
+
"HypothesisRecord",
|
| 27 |
+
"Scratchpad",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass(frozen=True)
|
| 32 |
+
class HypothesisRecord:
|
| 33 |
+
"""A recorded hypothesis about a data-quality root cause.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
claim: Textual description of the hypothesis.
|
| 37 |
+
affected_rows: Row indices the hypothesis covers.
|
| 38 |
+
affected_columns: Column names the hypothesis covers.
|
| 39 |
+
root_cause_type: Detector-vocabulary root cause type.
|
| 40 |
+
confirmed: Whether the hypothesis was confirmed by ground truth.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
claim: str
|
| 44 |
+
affected_rows: tuple[int, ...]
|
| 45 |
+
affected_columns: tuple[str, ...]
|
| 46 |
+
root_cause_type: str
|
| 47 |
+
confirmed: bool = False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass(frozen=True)
|
| 51 |
+
class ConfirmedIssue:
|
| 52 |
+
"""A confirmed data-quality issue at a specific location.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
row: Zero-indexed row number.
|
| 56 |
+
column: Column name.
|
| 57 |
+
issue_type: Issue type classification.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
row: int
|
| 61 |
+
column: str
|
| 62 |
+
issue_type: str
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass(frozen=True)
|
| 66 |
+
class DeadEnd:
|
| 67 |
+
"""A recorded dead end — an investigation path that yielded nothing.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
description: What was tried and why it failed.
|
| 71 |
+
step_number: Step at which the dead end was recorded.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
description: str
|
| 75 |
+
step_number: int
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass
|
| 79 |
+
class Scratchpad:
|
| 80 |
+
"""Mutable in-episode tracker for hypotheses, confirmed issues, and dead ends.
|
| 81 |
+
|
| 82 |
+
Reset at the start of each episode. The ``summary()`` method produces a
|
| 83 |
+
compact string for inclusion in agent observations.
|
| 84 |
+
|
| 85 |
+
Example::
|
| 86 |
+
|
| 87 |
+
>>> pad = Scratchpad()
|
| 88 |
+
>>> pad.add_hypothesis("Decimal shift in rating", [5], ["rating"], "decimal_shift")
|
| 89 |
+
>>> len(pad.hypotheses)
|
| 90 |
+
1
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
hypotheses: list[HypothesisRecord] = field(default_factory=list)
|
| 94 |
+
confirmed_issues: list[ConfirmedIssue] = field(default_factory=list)
|
| 95 |
+
dead_ends: list[DeadEnd] = field(default_factory=list)
|
| 96 |
+
|
| 97 |
+
def add_hypothesis(
|
| 98 |
+
self,
|
| 99 |
+
claim: str,
|
| 100 |
+
affected_rows: list[int],
|
| 101 |
+
affected_columns: list[str],
|
| 102 |
+
root_cause_type: str,
|
| 103 |
+
) -> HypothesisRecord:
|
| 104 |
+
"""Record a new hypothesis.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
claim: Textual description of the hypothesis.
|
| 108 |
+
affected_rows: Row indices the hypothesis covers.
|
| 109 |
+
affected_columns: Column names the hypothesis covers.
|
| 110 |
+
root_cause_type: Detector-vocabulary root cause type.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
The recorded hypothesis.
|
| 114 |
+
"""
|
| 115 |
+
record = HypothesisRecord(
|
| 116 |
+
claim=claim,
|
| 117 |
+
affected_rows=tuple(affected_rows),
|
| 118 |
+
affected_columns=tuple(affected_columns),
|
| 119 |
+
root_cause_type=root_cause_type,
|
| 120 |
+
)
|
| 121 |
+
self.hypotheses.append(record)
|
| 122 |
+
return record
|
| 123 |
+
|
| 124 |
+
def confirm_hypothesis(self, index: int) -> None:
|
| 125 |
+
"""Mark a hypothesis as confirmed.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
index: Index into the ``hypotheses`` list.
|
| 129 |
+
|
| 130 |
+
Raises:
|
| 131 |
+
IndexError: If the index is out of range.
|
| 132 |
+
"""
|
| 133 |
+
old = self.hypotheses[index]
|
| 134 |
+
self.hypotheses[index] = HypothesisRecord(
|
| 135 |
+
claim=old.claim,
|
| 136 |
+
affected_rows=old.affected_rows,
|
| 137 |
+
affected_columns=old.affected_columns,
|
| 138 |
+
root_cause_type=old.root_cause_type,
|
| 139 |
+
confirmed=True,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def confirm_issue(self, row: int, column: str, issue_type: str) -> None:
|
| 143 |
+
"""Record a confirmed issue.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
row: Zero-indexed row number.
|
| 147 |
+
column: Column name.
|
| 148 |
+
issue_type: Issue type classification.
|
| 149 |
+
"""
|
| 150 |
+
self.confirmed_issues.append(ConfirmedIssue(row=row, column=column, issue_type=issue_type))
|
| 151 |
+
|
| 152 |
+
def add_dead_end(self, description: str, step_number: int) -> None:
|
| 153 |
+
"""Record a dead end.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
description: What was tried and why it failed.
|
| 157 |
+
step_number: Step at which the dead end was recorded.
|
| 158 |
+
"""
|
| 159 |
+
self.dead_ends.append(DeadEnd(description=description, step_number=step_number))
|
| 160 |
+
|
| 161 |
+
def reset(self) -> None:
|
| 162 |
+
"""Clear all tracked state for a new episode."""
|
| 163 |
+
self.hypotheses.clear()
|
| 164 |
+
self.confirmed_issues.clear()
|
| 165 |
+
self.dead_ends.clear()
|
| 166 |
+
|
| 167 |
+
def summary(self) -> str:
|
| 168 |
+
"""Produce a compact summary string for observation embedding.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
A one-line summary of scratchpad state.
|
| 172 |
+
|
| 173 |
+
Example::
|
| 174 |
+
|
| 175 |
+
>>> Scratchpad().summary()
|
| 176 |
+
'Hypotheses: 0 (0 pending). Confirmed: 0. Dead ends: 0.'
|
| 177 |
+
"""
|
| 178 |
+
pending = sum(1 for h in self.hypotheses if not h.confirmed)
|
| 179 |
+
return (
|
| 180 |
+
f"Hypotheses: {len(self.hypotheses)} ({pending} pending). "
|
| 181 |
+
f"Confirmed: {len(self.confirmed_issues)}. "
|
| 182 |
+
f"Dead ends: {len(self.dead_ends)}."
|
| 183 |
+
)
|
dataforge/agent/tool_actions.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Typed tool-use action models for the DataForge RL environment.
|
| 2 |
+
|
| 3 |
+
This module defines a discriminated union of 8 action types that an RL agent
|
| 4 |
+
can submit to the DataForge environment. Each action is a standalone Pydantic
|
| 5 |
+
model with its own validation rules, preventing cross-model field pollution.
|
| 6 |
+
|
| 7 |
+
The ``parse_action`` function is the single entry point for HTTP handlers
|
| 8 |
+
and tests to validate raw action dicts into typed models.
|
| 9 |
+
|
| 10 |
+
Action Types:
|
| 11 |
+
INSPECT_ROWS — View a slice of the dataset.
|
| 12 |
+
SQL_QUERY — Execute read-only SQL against the episode DataFrame.
|
| 13 |
+
STAT_TEST — Run a statistical test on a column.
|
| 14 |
+
PATTERN_MATCH — Evaluate a regex pattern against column values.
|
| 15 |
+
HYPOTHESIS — Record a causal-root claim for credit.
|
| 16 |
+
ROOT_CAUSE — Analyze selected detected errors for minimal roots.
|
| 17 |
+
DIAGNOSE — Flag a suspected issue at (row, column).
|
| 18 |
+
FIX — Propose a corrected value for a diagnosed issue.
|
| 19 |
+
|
| 20 |
+
Example::
|
| 21 |
+
|
| 22 |
+
>>> from dataforge.agent.tool_actions import parse_action
|
| 23 |
+
>>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0, 1]})
|
| 24 |
+
>>> action.action_type
|
| 25 |
+
'INSPECT_ROWS'
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
from typing import Annotated, Any, Literal
|
| 31 |
+
|
| 32 |
+
from pydantic import BaseModel, Field, field_validator
|
| 33 |
+
|
| 34 |
+
__all__ = [
|
| 35 |
+
"Action",
|
| 36 |
+
"Diagnose",
|
| 37 |
+
"Fix",
|
| 38 |
+
"Hypothesis",
|
| 39 |
+
"InspectRows",
|
| 40 |
+
"PatternMatch",
|
| 41 |
+
"RootCause",
|
| 42 |
+
"SqlQuery",
|
| 43 |
+
"StatTest",
|
| 44 |
+
"parse_action",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class InspectRows(BaseModel):
    """Action that requests a view of specific dataset rows.

    Args:
        action_type: Discriminator; always ``"INSPECT_ROWS"``.
        row_indices: Zero-indexed rows to retrieve. At least one index is
            required and none may be negative.
        column_names: Optional column filter. If omitted, all columns are
            returned.

    Example::

        >>> InspectRows(action_type="INSPECT_ROWS", row_indices=[0, 1, 2])
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["INSPECT_ROWS"]
    row_indices: list[int] = Field(min_length=1, description="Row indices to inspect (0-indexed).")
    column_names: list[str] | None = Field(default=None, description="Optional column filter.")

    @field_validator("row_indices")
    @classmethod
    def _validate_row_indices(cls, indices: list[int]) -> list[int]:
        """Reject the list when it contains a negative row index."""
        for index in indices:
            if index < 0:
                raise ValueError("All row indices must be >= 0")
        return indices
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class SqlQuery(BaseModel):
    """Action that runs a read-only SQL query over the episode DataFrame (DuckDB).

    Args:
        action_type: Discriminator; always ``"SQL_QUERY"``.
        query: Non-empty SQL text. It is expected to be read-only (SELECT
            only); this model itself only enforces that it is non-empty.

    Example::

        >>> SqlQuery(action_type="SQL_QUERY", query="SELECT * FROM data LIMIT 5")
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["SQL_QUERY"]
    query: str = Field(min_length=1, description="Read-only SQL query.")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class StatTest(BaseModel):
    """Action that runs a statistical test against one dataset column.

    Args:
        action_type: Discriminator; always ``"STAT_TEST"``.
        test_type: Which test to run: ``"zscore"``, ``"iqr"`` or ``"ks"``.
        column: Non-empty name of the column to test.
        threshold: Optional threshold override; ``None`` leaves the choice
            of default to the test implementation.

    Example::

        >>> StatTest(action_type="STAT_TEST", test_type="zscore", column="rating")
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["STAT_TEST"]
    test_type: Literal["zscore", "iqr", "ks"] = Field(description="Statistical test to run.")
    column: str = Field(min_length=1, description="Column name to test.")
    threshold: float | None = Field(default=None, description="Optional threshold override.")
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class PatternMatch(BaseModel):
    """Action that evaluates a regular expression against a column's values.

    Args:
        action_type: Discriminator; always ``"PATTERN_MATCH"``.
        pattern: Non-empty regular expression string.
        column: Non-empty name of the column to evaluate.
        expect_match: ``True`` (default) reports rows that match the
            pattern; ``False`` reports the rows that do not.

    Example::

        >>> PatternMatch(
        ...     action_type="PATTERN_MATCH",
        ...     pattern=r"^\\d{5}$",
        ...     column="zip_code",
        ... )
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["PATTERN_MATCH"]
    pattern: str = Field(min_length=1, description="Regex pattern.")
    column: str = Field(min_length=1, description="Column name to evaluate.")
    expect_match: bool = Field(
        default=True,
        description="True to report matches, False to report non-matches.",
    )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Hypothesis(BaseModel):
    """Action that records a causal-root claim so it can earn root-cause credit.

    Args:
        action_type: Discriminator; always ``"HYPOTHESIS"``.
        claim: Non-empty text describing the hypothesized root cause.
        affected_rows: Row indices believed to be affected. At least one
            index is required and none may be negative.
        affected_columns: Column names believed to be affected (at least one).
        root_cause_type: Detector-vocabulary root cause type
            (e.g., ``"decimal_shift"``, ``"type_mismatch"``).

    Example::

        >>> Hypothesis(
        ...     action_type="HYPOTHESIS",
        ...     claim="Column 'rating' has a decimal shift at row 5",
        ...     affected_rows=[5],
        ...     affected_columns=["rating"],
        ...     root_cause_type="decimal_shift",
        ... )
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["HYPOTHESIS"]
    claim: str = Field(min_length=1, description="Root-cause claim.")
    affected_rows: list[int] = Field(min_length=1, description="Affected row indices.")
    affected_columns: list[str] = Field(min_length=1, description="Affected column names.")
    root_cause_type: str = Field(min_length=1, description="Detector-vocabulary root cause type.")

    @field_validator("affected_rows")
    @classmethod
    def _validate_affected_rows(cls, rows: list[int]) -> list[int]:
        """Reject the list when it contains a negative row index."""
        for row in rows:
            if row < 0:
                raise ValueError("All affected row indices must be >= 0")
        return rows
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class RootCause(BaseModel):
    """Action that asks for minimal causal roots of selected detected errors.

    Args:
        action_type: Discriminator; always ``"ROOT_CAUSE"``.
        error_indices: Zero-based indices into the episode's detected issue
            list. At least one index is required and none may be negative.

    Example::

        >>> RootCause(action_type="ROOT_CAUSE", error_indices=[0, 1])
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["ROOT_CAUSE"]
    error_indices: list[int] = Field(min_length=1, description="Detected issue indices.")

    @field_validator("error_indices")
    @classmethod
    def _validate_error_indices(cls, indices: list[int]) -> list[int]:
        """Reject the list when it contains a negative issue index."""
        for index in indices:
            if index < 0:
                raise ValueError("All error indices must be >= 0")
        return indices
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class Diagnose(BaseModel):
    """Action that flags a suspected data-quality issue at one (row, column) cell.

    Args:
        action_type: Discriminator; always ``"DIAGNOSE"``.
        row: Zero-indexed row number (must be >= 0).
        column: Non-empty column name.
        issue_type: Non-empty issue type from the detector vocabulary.

    Example::

        >>> Diagnose(
        ...     action_type="DIAGNOSE",
        ...     row=5, column="rating",
        ...     issue_type="decimal_shift",
        ... )
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["DIAGNOSE"]
    row: int = Field(ge=0, description="Zero-indexed row number.")
    column: str = Field(min_length=1, description="Column name.")
    issue_type: str = Field(min_length=1, description="Issue type classification.")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class Fix(BaseModel):
    """Action that proposes a corrected value for a diagnosed cell.

    Args:
        action_type: Discriminator; always ``"FIX"``.
        row: Zero-indexed row number (must be >= 0).
        column: Non-empty column name.
        new_value: The corrected cell value, as a string (may be empty).
        justification: Non-empty explanation of why this fix is correct.
        fix_type: Kind of repair: ``"correct_value"`` (default),
            ``"delete_row"``, ``"impute"`` or ``"standardize"``.

    Example::

        >>> Fix(
        ...     action_type="FIX",
        ...     row=5, column="rating",
        ...     new_value="4.5",
        ...     justification="Decimal shift: 45.0 should be 4.5",
        ... )
    """

    # Validated instances are immutable.
    model_config = {"frozen": True}

    action_type: Literal["FIX"]
    row: int = Field(ge=0, description="Zero-indexed row number.")
    column: str = Field(min_length=1, description="Column name.")
    new_value: str = Field(description="Corrected cell value.")
    justification: str = Field(min_length=1, description="Fix justification.")
    fix_type: Literal["correct_value", "delete_row", "impute", "standardize"] = Field(
        default="correct_value", description="Fix operation type."
    )
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 270 |
+
# Discriminated union and parser
|
| 271 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 272 |
+
|
| 273 |
+
# The ``action_type`` literal on each member model is the discriminator, so a
# raw payload is routed to exactly one of the eight models during validation.
Action = Annotated[
    InspectRows | SqlQuery | StatTest | PatternMatch | Hypothesis | RootCause | Diagnose | Fix,
    Field(discriminator="action_type"),
]
"""Discriminated union of all valid DataForge environment actions."""
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def parse_action(raw: dict[str, Any]) -> Action:
    """Parse and validate a raw action dict into the appropriate typed model.

    This is the single entry point for HTTP handlers and tests to validate
    actions. External aliases are first rewritten by ``_normalize_action``;
    the ``action_type`` field is then used as the discriminator to select
    the concrete model.

    Args:
        raw: Dictionary with an ``action_type`` key and action-specific fields.
            The dictionary itself is not modified.

    Returns:
        A validated action model instance.

    Raises:
        pydantic.ValidationError: If the action is malformed or invalid. This
            includes a missing or unrecognized ``action_type``: the lookup is
            done with ``dict.get`` and the discriminated union reports both
            cases as validation errors, so no ``KeyError`` is raised here.
            In pydantic v2 ``ValidationError`` is a ``ValueError`` subclass.

    Example::

        >>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0]})
        >>> isinstance(action, InspectRows)
        True
    """
    from pydantic import TypeAdapter

    # Constructing a TypeAdapter builds the validator for the whole union, so
    # do it once and keep it on the function object instead of once per call.
    adapter: TypeAdapter[Action] | None = getattr(parse_action, "_adapter", None)
    if adapter is None:
        adapter = TypeAdapter(Action)
        parse_action._adapter = adapter  # type: ignore[attr-defined]
    return adapter.validate_python(_normalize_action(raw))
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _normalize_action(raw: dict[str, Any]) -> dict[str, Any]:
    """Return a canonical action dictionary from supported external aliases.

    The input is shallow-copied and never mutated. Alias keys are left in the
    copy; a canonical key that is already present always wins over its alias.

    Supported aliases, by ``action_type``:

    * ``SQL_QUERY``: ``sql`` -> ``query``
    * ``STAT_TEST``: ``test`` -> ``test_type``
    * ``PATTERN_MATCH``: ``regex`` -> ``pattern``; ``expect`` ->
      ``expect_match`` (a bool is passed through, otherwise the value is
      compared against the string ``"match"``)
    * ``HYPOTHESIS``: ``root_column`` (+ optional ``downstream`` list) ->
      ``affected_columns``; ``root_column`` -> ``root_cause_type``;
      ``affected_rows`` defaults to ``[0]``
    * ``ROOT_CAUSE``: ``indices`` -> ``error_indices``
    * ``FIX``: ``proposed_value`` -> ``new_value``; ``justification`` gets a
      stock default when absent

    Args:
        raw: Raw action payload as received from the caller.

    Returns:
        A new dictionary with canonical keys filled in where possible.
    """
    normalized = dict(raw)
    action_type = normalized.get("action_type")

    def _alias(alias: str, canonical: str) -> None:
        """Copy *alias* onto *canonical* unless *canonical* is already set."""
        if alias in normalized and canonical not in normalized:
            normalized[canonical] = normalized[alias]

    if action_type == "SQL_QUERY":
        _alias("sql", "query")
    elif action_type == "STAT_TEST":
        _alias("test", "test_type")
    elif action_type == "PATTERN_MATCH":
        _alias("regex", "pattern")
        if "expect" in normalized and "expect_match" not in normalized:
            expect = normalized["expect"]
            # A caller may already send a boolean; comparing it to "match"
            # would silently turn ``True`` into ``False``.
            normalized["expect_match"] = expect if isinstance(expect, bool) else expect == "match"
    elif action_type == "HYPOTHESIS":
        root_column = normalized.get("root_column")
        if root_column is not None and "affected_columns" not in normalized:
            downstream = normalized.get("downstream")
            # Anything other than a list is ignored rather than rejected.
            downstream_columns = downstream if isinstance(downstream, list) else []
            normalized["affected_columns"] = [root_column, *downstream_columns]
        if "affected_rows" not in normalized:
            normalized["affected_rows"] = [0]
        if root_column is not None and "root_cause_type" not in normalized:
            normalized["root_cause_type"] = root_column
    elif action_type == "ROOT_CAUSE":
        _alias("indices", "error_indices")
    elif action_type == "FIX":
        _alias("proposed_value", "new_value")
        if "justification" not in normalized:
            normalized["justification"] = "Agent proposed value via FIX."
    return normalized
|
dataforge/bench/core.py
CHANGED
|
@@ -59,6 +59,7 @@ class SeedBenchmarkResult(BaseModel):
|
|
| 59 |
prompt_tokens: int = Field(ge=0, default=0)
|
| 60 |
completion_tokens: int = Field(ge=0, default=0)
|
| 61 |
quota_units: float = Field(ge=0.0, default=0.0)
|
|
|
|
| 62 |
runtime_s: float = Field(ge=0.0, default=0.0)
|
| 63 |
provider: str | None = None
|
| 64 |
model: str | None = None
|
|
@@ -85,6 +86,8 @@ class AggregateBenchmarkResult(BaseModel):
|
|
| 85 |
avg_steps_std: float | None = None
|
| 86 |
quota_units_mean: float | None = None
|
| 87 |
quota_units_std: float | None = None
|
|
|
|
|
|
|
| 88 |
runtime_s_mean: float | None = None
|
| 89 |
runtime_s_std: float | None = None
|
| 90 |
provider: str | None = None
|
|
@@ -229,6 +232,7 @@ def aggregate_seed_results(
|
|
| 229 |
f1_mean, f1_std = _mean_std([row.f1 or 0.0 for row in ok_rows])
|
| 230 |
avg_steps_mean, avg_steps_std = _mean_std([row.avg_steps or 0.0 for row in ok_rows])
|
| 231 |
quota_mean, quota_std = _mean_std([row.quota_units for row in ok_rows])
|
|
|
|
| 232 |
runtime_mean, runtime_std = _mean_std([row.runtime_s for row in ok_rows])
|
| 233 |
aggregates.append(
|
| 234 |
AggregateBenchmarkResult(
|
|
@@ -248,6 +252,8 @@ def aggregate_seed_results(
|
|
| 248 |
avg_steps_std=avg_steps_std,
|
| 249 |
quota_units_mean=quota_mean,
|
| 250 |
quota_units_std=quota_std,
|
|
|
|
|
|
|
| 251 |
runtime_s_mean=runtime_mean,
|
| 252 |
runtime_s_std=runtime_std,
|
| 253 |
provider=ok_rows[0].provider,
|
|
|
|
| 59 |
prompt_tokens: int = Field(ge=0, default=0)
|
| 60 |
completion_tokens: int = Field(ge=0, default=0)
|
| 61 |
quota_units: float = Field(ge=0.0, default=0.0)
|
| 62 |
+
gpu_hours: float = Field(ge=0.0, default=0.0)
|
| 63 |
runtime_s: float = Field(ge=0.0, default=0.0)
|
| 64 |
provider: str | None = None
|
| 65 |
model: str | None = None
|
|
|
|
| 86 |
avg_steps_std: float | None = None
|
| 87 |
quota_units_mean: float | None = None
|
| 88 |
quota_units_std: float | None = None
|
| 89 |
+
gpu_hours_mean: float | None = None
|
| 90 |
+
gpu_hours_std: float | None = None
|
| 91 |
runtime_s_mean: float | None = None
|
| 92 |
runtime_s_std: float | None = None
|
| 93 |
provider: str | None = None
|
|
|
|
| 232 |
f1_mean, f1_std = _mean_std([row.f1 or 0.0 for row in ok_rows])
|
| 233 |
avg_steps_mean, avg_steps_std = _mean_std([row.avg_steps or 0.0 for row in ok_rows])
|
| 234 |
quota_mean, quota_std = _mean_std([row.quota_units for row in ok_rows])
|
| 235 |
+
gpu_hours_mean, gpu_hours_std = _mean_std([row.gpu_hours for row in ok_rows])
|
| 236 |
runtime_mean, runtime_std = _mean_std([row.runtime_s for row in ok_rows])
|
| 237 |
aggregates.append(
|
| 238 |
AggregateBenchmarkResult(
|
|
|
|
| 252 |
avg_steps_std=avg_steps_std,
|
| 253 |
quota_units_mean=quota_mean,
|
| 254 |
quota_units_std=quota_std,
|
| 255 |
+
gpu_hours_mean=gpu_hours_mean,
|
| 256 |
+
gpu_hours_std=gpu_hours_std,
|
| 257 |
runtime_s_mean=runtime_mean,
|
| 258 |
runtime_s_std=runtime_std,
|
| 259 |
provider=ok_rows[0].provider,
|
dataforge/bench/groq_client.py
CHANGED
|
@@ -1,21 +1,45 @@
|
|
| 1 |
-
"""Minimal
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
|
|
|
| 6 |
import time
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from typing import cast
|
| 9 |
|
| 10 |
import httpx
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def _is_rate_limit_error(exc: BaseException) -> bool:
|
| 15 |
-
"""Return whether an exception is
|
| 16 |
return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
@dataclass(frozen=True, kw_only=True)
|
| 20 |
class GroqCompletion:
|
| 21 |
"""Completion payload plus conservative usage accounting."""
|
|
@@ -26,26 +50,50 @@ class GroqCompletion:
|
|
| 26 |
warnings: tuple[str, ...]
|
| 27 |
|
| 28 |
|
| 29 |
-
class
|
| 30 |
-
"""Sequential
|
| 31 |
|
| 32 |
def __init__(
|
| 33 |
self,
|
| 34 |
*,
|
| 35 |
api_key: str,
|
| 36 |
-
model: str
|
|
|
|
|
|
|
| 37 |
min_interval_s: float = 2.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
) -> None:
|
| 39 |
self._api_key = api_key
|
| 40 |
self._model = model
|
|
|
|
|
|
|
| 41 |
self._min_interval_s = min_interval_s
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
self._last_success_at: float | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
@property
|
| 45 |
def model(self) -> str:
|
| 46 |
-
"""Return the configured
|
| 47 |
return self._model
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def _respect_spacing(self) -> None:
|
| 50 |
"""Sleep long enough to keep requests sequential with a fixed gap."""
|
| 51 |
if self._last_success_at is None:
|
|
@@ -55,33 +103,57 @@ class GroqBenchClient:
|
|
| 55 |
if remaining > 0:
|
| 56 |
time.sleep(remaining)
|
| 57 |
|
| 58 |
-
@retry(
|
| 59 |
-
retry=retry_if_exception(_is_rate_limit_error),
|
| 60 |
-
wait=wait_fixed(2),
|
| 61 |
-
stop=stop_after_attempt(3),
|
| 62 |
-
reraise=True,
|
| 63 |
-
)
|
| 64 |
def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
|
| 65 |
-
"""Issue the underlying
|
| 66 |
payload = {
|
| 67 |
"model": self._model,
|
| 68 |
"messages": messages,
|
| 69 |
"temperature": 0.0,
|
|
|
|
| 70 |
}
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
|
| 84 |
-
"""Send one benchmark completion request to
|
| 85 |
self._respect_spacing()
|
| 86 |
payload = self._post(messages)
|
| 87 |
self._last_success_at = time.monotonic()
|
|
@@ -92,16 +164,223 @@ class GroqBenchClient:
|
|
| 92 |
completion_tokens = int(usage.get("completion_tokens", 0)) if isinstance(usage, dict) else 0
|
| 93 |
if not usage:
|
| 94 |
warnings.append("missing_usage_payload")
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
try:
|
| 97 |
choices = cast(list[dict[str, object]], payload["choices"])
|
| 98 |
message = cast(dict[str, object], choices[0]["message"])
|
| 99 |
content = str(message["content"])
|
| 100 |
except (KeyError, IndexError, TypeError) as exc:
|
| 101 |
-
raise ValueError(
|
|
|
|
|
|
|
| 102 |
return GroqCompletion(
|
| 103 |
text=content,
|
| 104 |
prompt_tokens=prompt_tokens,
|
| 105 |
completion_tokens=completion_tokens,
|
| 106 |
warnings=tuple(warnings),
|
| 107 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal OpenAI-compatible clients for benchmark-only LLM baselines."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import json
|
| 6 |
+
import logging
|
| 7 |
import time
|
| 8 |
from dataclasses import dataclass
|
| 9 |
from typing import cast
|
| 10 |
|
| 11 |
import httpx
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ProviderRequestError(RuntimeError):
    """Error raised when a provider rejects a benchmark request payload.

    Base class for the provider-specific failures raised by the benchmark
    clients in this module.
    """
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ProviderRateLimitError(ProviderRequestError):
    """Error raised when the provider's requested wait exceeds the configured cap.

    Subclasses ``ProviderRequestError`` so callers that catch the base class
    also see rate-limit failures.
    """
|
| 20 |
|
| 21 |
|
| 22 |
def _is_rate_limit_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an ``httpx.HTTPStatusError`` carrying HTTP 429."""
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code == 429
|
| 25 |
|
| 26 |
|
| 27 |
+
def _is_retryable_provider_error(exc: BaseException) -> bool:
    """Tell whether *exc* is an HTTP status error worth retrying (429 or 503)."""
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    return exc.response.status_code in {429, 503}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _retry_after_s(exc: httpx.HTTPStatusError, *, fallback_s: float) -> float:
    """Return how long to wait before retrying, honoring ``Retry-After``.

    The header may carry either a number of seconds or an HTTP-date. Both
    forms are understood; anything unparsable (or a NaN value, which would
    otherwise reach ``time.sleep`` and make it raise) yields *fallback_s*.

    Args:
        exc: The HTTP status error whose response headers are inspected.
        fallback_s: Delay used when the header is missing or unusable; it is
            also the lower bound of the returned delay.

    Returns:
        The delay in seconds, never smaller than *fallback_s*.
    """
    import math
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    raw_retry_after = exc.response.headers.get("retry-after")
    if raw_retry_after is None:
        return fallback_s
    try:
        delay_s = float(raw_retry_after)
    except ValueError:
        # Not delta-seconds; try the HTTP-date form of the header.
        try:
            retry_at = parsedate_to_datetime(raw_retry_after)
        except (TypeError, ValueError):
            return fallback_s
        if retry_at.tzinfo is None:
            # Dates without a usable zone are treated as UTC.
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        delay_s = (retry_at - datetime.now(timezone.utc)).total_seconds()
    if math.isnan(delay_s):
        # max() with NaN would return NaN and break the caller's sleep.
        return fallback_s
    return max(delay_s, fallback_s)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
@dataclass(frozen=True, kw_only=True)
|
| 44 |
class GroqCompletion:
|
| 45 |
"""Completion payload plus conservative usage accounting."""
|
|
|
|
| 50 |
warnings: tuple[str, ...]
|
| 51 |
|
| 52 |
|
| 53 |
+
class OpenAICompatBenchClient:
|
| 54 |
+
"""Sequential OpenAI-compatible client with fixed 429 retry and spacing."""
|
| 55 |
|
| 56 |
def __init__(
    self,
    *,
    api_key: str,
    model: str,
    endpoint: str,
    provider: str,
    min_interval_s: float = 2.0,
    max_tokens: int = 512,
    max_retries: int = 5,
    max_retry_after_s: float = 120.0,
    timeout_s: float = 60.0,
) -> None:
    """Store the client configuration and open the shared HTTP client.

    Args:
        api_key: Secret sent as a bearer token on every request.
        model: Provider model name placed in each request payload.
        endpoint: URL that requests are POSTed to.
        provider: Provider identifier used in log and error messages.
        min_interval_s: Minimum gap, in seconds, kept between requests.
        max_tokens: ``max_tokens`` value sent with each request.
        max_retries: Number of attempts made per request.
        max_retry_after_s: Longest retry delay, in seconds, that is accepted.
        timeout_s: Timeout, in seconds, configured on the HTTP client.
    """
    # Plain configuration, kept in the same attributes the methods read.
    self._api_key = api_key
    self._model = model
    self._endpoint = endpoint
    self._provider = provider
    self._min_interval_s = min_interval_s
    self._max_tokens = max_tokens
    self._max_retries = max_retries
    self._max_retry_after_s = max_retry_after_s
    self._timeout_s = timeout_s

    # No request has succeeded yet.
    self._last_success_at: float | None = None

    # One client for the lifetime of this object, with auth preconfigured.
    default_headers = {
        "Authorization": f"Bearer {self._api_key}",
        "Content-Type": "application/json",
    }
    self._client = httpx.Client(timeout=self._timeout_s, headers=default_headers)
|
| 86 |
|
| 87 |
@property
def model(self) -> str:
    """Name of the provider model this client was configured with."""
    return self._model
|
| 91 |
|
| 92 |
+
@property
def provider(self) -> str:
    """Identifier of the provider this client was configured with."""
    return self._provider
|
| 96 |
+
|
| 97 |
def _respect_spacing(self) -> None:
|
| 98 |
"""Sleep long enough to keep requests sequential with a fixed gap."""
|
| 99 |
if self._last_success_at is None:
|
|
|
|
| 103 |
if remaining > 0:
|
| 104 |
time.sleep(remaining)
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
    """Issue the underlying chat-completions request.

    Up to ``self._max_retries`` attempts are made. Retryable HTTP errors
    (see ``_is_retryable_provider_error``) are retried after the delay from
    ``_retry_after_s``, whose fallback grows by two seconds per attempt.

    Args:
        messages: Chat messages placed verbatim in the request payload.

    Returns:
        The decoded JSON response body as a dictionary.

    Raises:
        ProviderRequestError: On a non-retryable HTTP error, or a retryable
            one on the final attempt.
        ProviderRateLimitError: When the retry delay exceeds
            ``self._max_retry_after_s``.
        TimeoutError: When the HTTP request times out (not retried).
        RuntimeError: When no attempt is made because ``self._max_retries``
            is not positive.
    """
    payload = {
        "model": self._model,
        "messages": messages,
        "temperature": 0.0,
        "max_tokens": self._max_tokens,
    }
    logger = logging.getLogger("dataforge.bench.groq_client")
    for attempt in range(self._max_retries):
        try:
            response = self._client.post(
                self._endpoint,
                json=payload,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            if not _is_retryable_provider_error(exc) or attempt == self._max_retries - 1:
                # Keep the message short and on one line.
                body = exc.response.text[:500].replace("\n", " ")
                raise ProviderRequestError(
                    f"{self._provider} request rejected with HTTP "
                    f"{exc.response.status_code}: {body}"
                ) from exc
            retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
            if retry_s > self._max_retry_after_s:
                body = exc.response.text[:500].replace("\n", " ")
                raise ProviderRateLimitError(
                    f"{self._provider} rate limit retry-after {retry_s:.2f}s "
                    f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
                ) from exc
            logger.warning(
                "%s_rate_limit attempt=%d retry_after_s=%.2f",
                self._provider,
                attempt + 1,
                retry_s,
            )
            time.sleep(retry_s)
            continue
        except httpx.TimeoutException as exc:
            raise TimeoutError(
                f"{self._provider} request timed out after {self._timeout_s:.1f} seconds."
            ) from exc
        return dict(response.json())
    # Every iteration either returns or raises, so this line is reached only
    # when the loop body never ran (non-positive ``max_retries``).
    raise RuntimeError(f"{self._provider} request failed without a response.")
|
| 154 |
|
| 155 |
def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
|
| 156 |
+
"""Send one benchmark completion request to the configured provider."""
|
| 157 |
self._respect_spacing()
|
| 158 |
payload = self._post(messages)
|
| 159 |
self._last_success_at = time.monotonic()
|
|
|
|
| 164 |
completion_tokens = int(usage.get("completion_tokens", 0)) if isinstance(usage, dict) else 0
|
| 165 |
if not usage:
|
| 166 |
warnings.append("missing_usage_payload")
|
| 167 |
+
logging.getLogger("dataforge.bench.groq_client").warning(
|
| 168 |
+
"%s_missing_usage_payload", self._provider
|
| 169 |
+
)
|
| 170 |
|
| 171 |
try:
|
| 172 |
choices = cast(list[dict[str, object]], payload["choices"])
|
| 173 |
message = cast(dict[str, object], choices[0]["message"])
|
| 174 |
content = str(message["content"])
|
| 175 |
except (KeyError, IndexError, TypeError) as exc:
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"Unexpected {self._provider} response payload: {json.dumps(payload)}"
|
| 178 |
+
) from exc
|
| 179 |
return GroqCompletion(
|
| 180 |
text=content,
|
| 181 |
prompt_tokens=prompt_tokens,
|
| 182 |
completion_tokens=completion_tokens,
|
| 183 |
warnings=tuple(warnings),
|
| 184 |
)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class GroqBenchClient(OpenAICompatBenchClient):
    """OpenAI-compatible benchmark client preconfigured for Groq.

    Only the endpoint URL and the ``"groq"`` provider identifier are fixed
    here; every other setting is forwarded unchanged to the base class.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "llama-3.3-70b-versatile",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Configure the base client for Groq's chat-completions endpoint."""
        super().__init__(
            provider="groq",
            endpoint="https://api.groq.com/openai/v1/chat/completions",
            api_key=api_key,
            model=model,
            min_interval_s=min_interval_s,
            max_tokens=max_tokens,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
            timeout_s=timeout_s,
        )
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class CerebrasBenchClient(OpenAICompatBenchClient):
    """OpenAI-compatible benchmark client preconfigured for Cerebras.

    Only the endpoint URL and the ``"cerebras"`` provider identifier are
    fixed here; every other setting is forwarded unchanged to the base class.
    The default request spacing (0.5 s) is shorter than the Groq client's.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "qwen-3-235b-a22b-instruct-2507",
        min_interval_s: float = 0.5,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        """Configure the base client for Cerebras' chat-completions endpoint."""
        super().__init__(
            provider="cerebras",
            endpoint="https://api.cerebras.ai/v1/chat/completions",
            api_key=api_key,
            model=model,
            min_interval_s=min_interval_s,
            max_tokens=max_tokens,
            max_retries=max_retries,
            max_retry_after_s=max_retry_after_s,
            timeout_s=timeout_s,
        )
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class GeminiBenchClient:
    """Sequential Gemini client adapted to the benchmark completion interface.

    Requests go through one shared ``httpx.Client``. After every successful
    completion the next call waits until at least ``min_interval_s`` seconds
    have elapsed, and HTTP errors accepted by ``_is_retryable_provider_error``
    are retried for up to ``max_retries`` attempts in total.

    The instance can be used as a context manager; leaving the ``with`` block
    (or calling :meth:`close`) releases the underlying HTTP connection pool.

    Args:
        api_key: Gemini API key, sent in the ``x-goog-api-key`` request header.
        model: Model name; a leading ``models/`` prefix is stripped.
        min_interval_s: Minimum seconds between a success and the next request.
        max_tokens: Value sent as ``generationConfig.maxOutputTokens``.
        max_retries: Total number of request attempts per completion.
        max_retry_after_s: Longest retry delay accepted before giving up.
        timeout_s: Timeout (seconds) handed to ``httpx.Client``.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str = "gemini-3.1-pro-preview",
        min_interval_s: float = 2.0,
        max_tokens: int = 512,
        max_retries: int = 5,
        max_retry_after_s: float = 120.0,
        timeout_s: float = 60.0,
    ) -> None:
        self._api_key = api_key
        self._model = model.removeprefix("models/")
        self._min_interval_s = min_interval_s
        self._max_tokens = max_tokens
        self._max_retries = max_retries
        self._max_retry_after_s = max_retry_after_s
        self._timeout_s = timeout_s
        # Monotonic timestamp of the last successful completion (None until one happens).
        self._last_success_at: float | None = None
        self._client = httpx.Client(
            timeout=self._timeout_s,
            headers={"Content-Type": "application/json"},
        )

    def close(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        self._client.close()

    def __enter__(self) -> "GeminiBenchClient":
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the HTTP client when the ``with`` block exits."""
        self.close()

    @property
    def model(self) -> str:
        """Return the configured Gemini model name."""
        return self._model

    @property
    def provider(self) -> str:
        """Return the provider identifier."""
        return "gemini"

    def _respect_spacing(self) -> None:
        """Sleep long enough to keep requests sequential with a fixed gap."""
        if self._last_success_at is None:
            # Nothing has succeeded yet, so there is no gap to enforce.
            return
        elapsed = time.monotonic() - self._last_success_at
        remaining = self._min_interval_s - elapsed
        if remaining > 0:
            time.sleep(remaining)

    def _payload(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Convert OpenAI-style chat messages to a Gemini generateContent payload.

        ``system`` messages are joined with blank lines into
        ``systemInstruction``; ``assistant`` messages become role ``model`` and
        every other role is sent as ``user``.
        """
        system_texts: list[str] = []
        contents: list[dict[str, object]] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if role == "system":
                system_texts.append(content)
                continue
            gemini_role = "model" if role == "assistant" else "user"
            contents.append({"role": gemini_role, "parts": [{"text": content}]})

        payload: dict[str, object] = {
            "contents": contents,
            "generationConfig": {
                "temperature": 0.0,
                "maxOutputTokens": self._max_tokens,
            },
        }
        if system_texts:
            payload["systemInstruction"] = {
                "parts": [{"text": "\n\n".join(system_texts)}],
            }
        return payload

    def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
        """Issue the underlying Gemini generateContent request.

        Returns:
            The decoded JSON response body.

        Raises:
            ProviderRequestError: On a non-retryable HTTP error, or when the
                final attempt still fails.
            ProviderRateLimitError: When the computed retry delay exceeds
                ``max_retry_after_s``.
            TimeoutError: When the request times out.
            RuntimeError: When no attempt was made (``max_retries`` <= 0).
        """
        endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta/models/{self._model}:generateContent"
        )
        # The payload is identical for every attempt, so build it once.
        body_json = self._payload(messages)
        # Send the key as a header rather than a ``?key=`` query parameter:
        # httpx embeds the request URL in HTTPStatusError messages, so a key in
        # the query string would leak into tracebacks and logs.
        auth_headers = {"x-goog-api-key": self._api_key}
        for attempt in range(self._max_retries):
            try:
                response = self._client.post(
                    endpoint,
                    headers=auth_headers,
                    json=body_json,
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                if not _is_retryable_provider_error(exc) or attempt == self._max_retries - 1:
                    body = exc.response.text[:500].replace("\n", " ")
                    raise ProviderRequestError(
                        f"gemini request rejected with HTTP {exc.response.status_code}: {body}"
                    ) from exc
                retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
                if retry_s > self._max_retry_after_s:
                    body = exc.response.text[:500].replace("\n", " ")
                    raise ProviderRateLimitError(
                        f"gemini rate limit retry-after {retry_s:.2f}s "
                        f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
                    ) from exc
                logging.getLogger("dataforge.bench.groq_client").warning(
                    "gemini_rate_limit attempt=%d retry_after_s=%.2f",
                    attempt + 1,
                    retry_s,
                )
                time.sleep(retry_s)
                continue
            except httpx.TimeoutException as exc:
                raise TimeoutError(
                    f"gemini request timed out after {self._timeout_s:.1f} seconds."
                ) from exc
            return dict(response.json())
        # Only reachable when the loop body never ran: the last attempt always
        # either returns or raises above.
        raise RuntimeError("gemini request failed without a response.")

    def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
        """Send one benchmark completion request to Gemini.

        Returns:
            A ``GroqCompletion`` holding the concatenated text of the first
            candidate, the reported token counts (0 when absent), and any
            warnings collected while decoding the response.

        Raises:
            ValueError: When the response lacks the expected
                ``candidates[0].content.parts`` structure.
        """
        self._respect_spacing()
        payload = self._post(messages)
        self._last_success_at = time.monotonic()

        warnings: list[str] = []
        usage = payload.get("usageMetadata", {})
        usage_counts = usage if isinstance(usage, dict) else {}
        prompt_tokens = int(usage_counts.get("promptTokenCount", 0))
        completion_tokens = int(usage_counts.get("candidatesTokenCount", 0))
        if not usage:
            warnings.append("missing_usage_payload")
            logging.getLogger("dataforge.bench.groq_client").warning("gemini_missing_usage_payload")

        try:
            candidates = cast(list[dict[str, object]], payload["candidates"])
            content = cast(dict[str, object], candidates[0]["content"])
            parts = cast(list[dict[str, object]], content["parts"])
            text = "".join(str(part.get("text", "")) for part in parts)
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            # AttributeError covers a ``parts`` entry that is not a mapping.
            raise ValueError(f"Unexpected gemini response payload: {json.dumps(payload)}") from exc
        return GroqCompletion(
            text=text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            warnings=tuple(warnings),
        )
|
dataforge/bench/methods.py
CHANGED
|
@@ -151,6 +151,40 @@ def _column_stats(
|
|
| 151 |
return stats
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def _extract_json_object(text: str) -> dict[str, object] | None:
|
| 155 |
"""Parse the first JSON object found in an LLM response string."""
|
| 156 |
stripped = text.strip()
|
|
@@ -158,6 +192,7 @@ def _extract_json_object(text: str) -> dict[str, object] | None:
|
|
| 158 |
stripped = stripped.strip("`")
|
| 159 |
if stripped.lower().startswith("json"):
|
| 160 |
stripped = stripped[4:].strip()
|
|
|
|
| 161 |
decoder = json.JSONDecoder()
|
| 162 |
for offset, char in enumerate(stripped):
|
| 163 |
if char != "{":
|
|
|
|
| 151 |
return stats
|
| 152 |
|
| 153 |
|
| 154 |
+
def _strip_json_line_comments(text: str) -> str:
|
| 155 |
+
"""Remove JavaScript-style line comments outside JSON strings."""
|
| 156 |
+
result: list[str] = []
|
| 157 |
+
in_string = False
|
| 158 |
+
escaped = False
|
| 159 |
+
index = 0
|
| 160 |
+
while index < len(text):
|
| 161 |
+
char = text[index]
|
| 162 |
+
next_char = text[index + 1] if index + 1 < len(text) else ""
|
| 163 |
+
if in_string:
|
| 164 |
+
result.append(char)
|
| 165 |
+
if escaped:
|
| 166 |
+
escaped = False
|
| 167 |
+
elif char == "\\":
|
| 168 |
+
escaped = True
|
| 169 |
+
elif char == '"':
|
| 170 |
+
in_string = False
|
| 171 |
+
index += 1
|
| 172 |
+
continue
|
| 173 |
+
if char == '"':
|
| 174 |
+
in_string = True
|
| 175 |
+
result.append(char)
|
| 176 |
+
index += 1
|
| 177 |
+
continue
|
| 178 |
+
if char == "/" and next_char == "/":
|
| 179 |
+
index += 2
|
| 180 |
+
while index < len(text) and text[index] not in "\r\n":
|
| 181 |
+
index += 1
|
| 182 |
+
continue
|
| 183 |
+
result.append(char)
|
| 184 |
+
index += 1
|
| 185 |
+
return "".join(result)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
def _extract_json_object(text: str) -> dict[str, object] | None:
|
| 189 |
"""Parse the first JSON object found in an LLM response string."""
|
| 190 |
stripped = text.strip()
|
|
|
|
| 192 |
stripped = stripped.strip("`")
|
| 193 |
if stripped.lower().startswith("json"):
|
| 194 |
stripped = stripped[4:].strip()
|
| 195 |
+
stripped = _strip_json_line_comments(stripped)
|
| 196 |
decoder = json.JSONDecoder()
|
| 197 |
for offset, char in enumerate(stripped):
|
| 198 |
if char != "{":
|
dataforge/bench/report.py
CHANGED
|
@@ -69,13 +69,14 @@ def _aggregate_across_datasets(aggregates: list[AggregateBenchmarkResult]) -> li
|
|
| 69 |
for method in methods:
|
| 70 |
ok_rows = grouped.get(method, [])
|
| 71 |
if not ok_rows:
|
| 72 |
-
rows.append([method, "Skipped", "Skipped", "Skipped", "Skipped", "Skipped"])
|
| 73 |
continue
|
| 74 |
p_mean = sum(row.precision_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 75 |
r_mean = sum(row.recall_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 76 |
f_mean = sum(row.f1_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 77 |
step_mean = sum(row.avg_steps_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 78 |
quota_mean = sum(row.quota_units_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
|
|
|
| 79 |
rows.append(
|
| 80 |
[
|
| 81 |
method,
|
|
@@ -84,6 +85,7 @@ def _aggregate_across_datasets(aggregates: list[AggregateBenchmarkResult]) -> li
|
|
| 84 |
f"{f_mean:.4f}",
|
| 85 |
f"{step_mean:.2f}",
|
| 86 |
f"{quota_mean:.4f}",
|
|
|
|
| 87 |
]
|
| 88 |
)
|
| 89 |
return rows
|
|
@@ -104,15 +106,13 @@ def build_readme_benchmark_block(agent_output: BenchmarkRunOutput, report_path:
|
|
| 104 |
"""Build the generated README benchmark summary block."""
|
| 105 |
rows = _aggregate_across_datasets(agent_output.aggregates)
|
| 106 |
table = _render_table(
|
| 107 |
-
["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"],
|
| 108 |
rows,
|
| 109 |
)
|
| 110 |
skip_reasons = _collect_skip_reasons(agent_output.aggregates)
|
| 111 |
skip_note = ""
|
| 112 |
if skip_reasons:
|
| 113 |
-
skip_note = (
|
| 114 |
-
"\n\nSkipped methods in this run: " + "; ".join(skip_reasons)
|
| 115 |
-
)
|
| 116 |
return (
|
| 117 |
"Generated from `eval/results/agent_comparison.json`.\n\n"
|
| 118 |
f"{table}\n\n"
|
|
@@ -140,19 +140,28 @@ def render_benchmark_report(
|
|
| 140 |
_format_metric(row.f1_mean, row.f1_std),
|
| 141 |
_format_metric(row.avg_steps_mean, row.avg_steps_std),
|
| 142 |
_format_metric(row.quota_units_mean, row.quota_units_std),
|
|
|
|
| 143 |
]
|
| 144 |
for row in rows
|
| 145 |
]
|
| 146 |
per_dataset_sections.append(
|
| 147 |
f"### {dataset.title()}\n\n"
|
| 148 |
+ _render_table(
|
| 149 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
table_rows,
|
| 151 |
)
|
| 152 |
)
|
| 153 |
|
| 154 |
local_summary = _render_table(
|
| 155 |
-
["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"],
|
| 156 |
_aggregate_across_datasets(agent_output.aggregates),
|
| 157 |
)
|
| 158 |
|
|
@@ -179,11 +188,7 @@ def render_benchmark_report(
|
|
| 179 |
skip_reasons = _collect_skip_reasons(agent_output.aggregates)
|
| 180 |
skip_note = ""
|
| 181 |
if skip_reasons:
|
| 182 |
-
skip_note = (
|
| 183 |
-
"\nSkipped methods in this reproduced run: "
|
| 184 |
-
+ "; ".join(skip_reasons)
|
| 185 |
-
+ "\n"
|
| 186 |
-
)
|
| 187 |
|
| 188 |
method_values = agent_output.metadata.get("methods", [])
|
| 189 |
dataset_values = agent_output.metadata.get("datasets", [])
|
|
@@ -203,6 +208,7 @@ def render_benchmark_report(
|
|
| 203 |
f"- Datasets: {', '.join(datasets)}\n"
|
| 204 |
f"- Seeds: {seeds}\n"
|
| 205 |
"- Free-tier quota units: `max(llm_calls / 1000, (prompt_tokens + completion_tokens) / 100000)`\n"
|
|
|
|
| 206 |
f"{skip_note}\n"
|
| 207 |
"## Cross-Dataset Local Results\n\n"
|
| 208 |
f"{local_summary}\n\n"
|
|
@@ -216,7 +222,7 @@ def render_benchmark_report(
|
|
| 216 |
sota_rows,
|
| 217 |
)
|
| 218 |
+ "\n\n## Methodology\n\n"
|
| 219 |
-
+ "Local rows are reproduced from generated JSON. Citation-only SOTA rows are copied from literature and are not rerun in this repository.
|
| 220 |
)
|
| 221 |
|
| 222 |
|
|
|
|
| 69 |
for method in methods:
|
| 70 |
ok_rows = grouped.get(method, [])
|
| 71 |
if not ok_rows:
|
| 72 |
+
rows.append([method, "Skipped", "Skipped", "Skipped", "Skipped", "Skipped", "Skipped"])
|
| 73 |
continue
|
| 74 |
p_mean = sum(row.precision_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 75 |
r_mean = sum(row.recall_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 76 |
f_mean = sum(row.f1_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 77 |
step_mean = sum(row.avg_steps_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 78 |
quota_mean = sum(row.quota_units_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 79 |
+
gpu_hours_mean = sum(row.gpu_hours_mean or 0.0 for row in ok_rows) / len(ok_rows)
|
| 80 |
rows.append(
|
| 81 |
[
|
| 82 |
method,
|
|
|
|
| 85 |
f"{f_mean:.4f}",
|
| 86 |
f"{step_mean:.2f}",
|
| 87 |
f"{quota_mean:.4f}",
|
| 88 |
+
f"{gpu_hours_mean:.4f}",
|
| 89 |
]
|
| 90 |
)
|
| 91 |
return rows
|
|
|
|
| 106 |
"""Build the generated README benchmark summary block."""
|
| 107 |
rows = _aggregate_across_datasets(agent_output.aggregates)
|
| 108 |
table = _render_table(
|
| 109 |
+
["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units", "GPU Hours"],
|
| 110 |
rows,
|
| 111 |
)
|
| 112 |
skip_reasons = _collect_skip_reasons(agent_output.aggregates)
|
| 113 |
skip_note = ""
|
| 114 |
if skip_reasons:
|
| 115 |
+
skip_note = "\n\nSkipped methods in this run: " + "; ".join(skip_reasons)
|
|
|
|
|
|
|
| 116 |
return (
|
| 117 |
"Generated from `eval/results/agent_comparison.json`.\n\n"
|
| 118 |
f"{table}\n\n"
|
|
|
|
| 140 |
_format_metric(row.f1_mean, row.f1_std),
|
| 141 |
_format_metric(row.avg_steps_mean, row.avg_steps_std),
|
| 142 |
_format_metric(row.quota_units_mean, row.quota_units_std),
|
| 143 |
+
_format_metric(row.gpu_hours_mean, row.gpu_hours_std),
|
| 144 |
]
|
| 145 |
for row in rows
|
| 146 |
]
|
| 147 |
per_dataset_sections.append(
|
| 148 |
f"### {dataset.title()}\n\n"
|
| 149 |
+ _render_table(
|
| 150 |
+
[
|
| 151 |
+
"Method",
|
| 152 |
+
"Precision",
|
| 153 |
+
"Recall",
|
| 154 |
+
"F1",
|
| 155 |
+
"Avg Steps",
|
| 156 |
+
"Quota Units",
|
| 157 |
+
"GPU Hours",
|
| 158 |
+
],
|
| 159 |
table_rows,
|
| 160 |
)
|
| 161 |
)
|
| 162 |
|
| 163 |
local_summary = _render_table(
|
| 164 |
+
["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units", "GPU Hours"],
|
| 165 |
_aggregate_across_datasets(agent_output.aggregates),
|
| 166 |
)
|
| 167 |
|
|
|
|
| 188 |
skip_reasons = _collect_skip_reasons(agent_output.aggregates)
|
| 189 |
skip_note = ""
|
| 190 |
if skip_reasons:
|
| 191 |
+
skip_note = "\nSkipped methods in this reproduced run: " + "; ".join(skip_reasons) + "\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
method_values = agent_output.metadata.get("methods", [])
|
| 194 |
dataset_values = agent_output.metadata.get("datasets", [])
|
|
|
|
| 208 |
f"- Datasets: {', '.join(datasets)}\n"
|
| 209 |
f"- Seeds: {seeds}\n"
|
| 210 |
"- Free-tier quota units: `max(llm_calls / 1000, (prompt_tokens + completion_tokens) / 100000)`\n"
|
| 211 |
+
"- GRPO compute cost is reported as free-tier GPU-hours, not dollars.\n"
|
| 212 |
f"{skip_note}\n"
|
| 213 |
"## Cross-Dataset Local Results\n\n"
|
| 214 |
f"{local_summary}\n\n"
|
|
|
|
| 222 |
sota_rows,
|
| 223 |
)
|
| 224 |
+ "\n\n## Methodology\n\n"
|
| 225 |
+
+ "Local rows are reproduced from generated JSON. Citation-only SOTA rows are copied from literature and are not rerun in this repository. LLM quota units are free-tier fractions; GRPO compute cost is GPU-hours, not dollars.\n"
|
| 226 |
)
|
| 227 |
|
| 228 |
|
dataforge/bench/runner.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import os
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
from dotenv import load_dotenv
|
|
@@ -96,18 +97,21 @@ def run_agent_comparison(
|
|
| 96 |
output_json: Path,
|
| 97 |
really_run_big_bench: bool,
|
| 98 |
cache_root: Path | None = None,
|
|
|
|
| 99 |
) -> BenchmarkRunOutput:
|
| 100 |
"""Run the selected benchmark methods across real-world datasets."""
|
| 101 |
load_dotenv()
|
| 102 |
_validate_inputs(methods, datasets, seeds)
|
| 103 |
|
| 104 |
estimated_calls = estimate_llm_calls(methods=methods, datasets=datasets, seeds=seeds)
|
|
|
|
|
|
|
| 105 |
validate_estimated_calls(
|
| 106 |
estimated_calls=estimated_calls,
|
| 107 |
really_run_big_bench=really_run_big_bench,
|
| 108 |
)
|
| 109 |
|
| 110 |
-
reproduction_command = _reproduction_command(methods, datasets, seeds)
|
| 111 |
records: list[SeedBenchmarkResult] = []
|
| 112 |
loaded_datasets = {
|
| 113 |
dataset_name: load_real_world_dataset(dataset_name, cache_root=cache_root)
|
|
@@ -116,16 +120,45 @@ def run_agent_comparison(
|
|
| 116 |
|
| 117 |
llm_methods_requested = any(method.startswith("llm_") for method in methods)
|
| 118 |
skip_reason = _llm_skip_reason() if llm_methods_requested else None
|
| 119 |
-
client =
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
for dataset_name in datasets:
|
| 126 |
dataset = loaded_datasets[dataset_name]
|
| 127 |
for method in methods:
|
| 128 |
for seed in range(seeds):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
if method == "random":
|
| 130 |
result = run_random_episode(dataset, seed=seed)
|
| 131 |
elif method == "heuristic":
|
|
@@ -159,6 +192,12 @@ def run_agent_comparison(
|
|
| 159 |
if method == "heuristic":
|
| 160 |
result = result.model_copy(update={"seed": seed})
|
| 161 |
records.append(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
|
| 164 |
records, seeds_requested=seeds
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import os
|
| 6 |
+
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
from dotenv import load_dotenv
|
|
|
|
| 97 |
output_json: Path,
|
| 98 |
really_run_big_bench: bool,
|
| 99 |
cache_root: Path | None = None,
|
| 100 |
+
reproduction_command: str | None = None,
|
| 101 |
) -> BenchmarkRunOutput:
|
| 102 |
"""Run the selected benchmark methods across real-world datasets."""
|
| 103 |
load_dotenv()
|
| 104 |
_validate_inputs(methods, datasets, seeds)
|
| 105 |
|
| 106 |
estimated_calls = estimate_llm_calls(methods=methods, datasets=datasets, seeds=seeds)
|
| 107 |
+
# Validate call budget before any client instantiation or dataset loads that could
|
| 108 |
+
# trigger network access in tests with environment variables set.
|
| 109 |
validate_estimated_calls(
|
| 110 |
estimated_calls=estimated_calls,
|
| 111 |
really_run_big_bench=really_run_big_bench,
|
| 112 |
)
|
| 113 |
|
| 114 |
+
reproduction_command = reproduction_command or _reproduction_command(methods, datasets, seeds)
|
| 115 |
records: list[SeedBenchmarkResult] = []
|
| 116 |
loaded_datasets = {
|
| 117 |
dataset_name: load_real_world_dataset(dataset_name, cache_root=cache_root)
|
|
|
|
| 120 |
|
| 121 |
llm_methods_requested = any(method.startswith("llm_") for method in methods)
|
| 122 |
skip_reason = _llm_skip_reason() if llm_methods_requested else None
|
| 123 |
+
client = None
|
| 124 |
+
if llm_methods_requested and skip_reason is None:
|
| 125 |
+
# Allow env-driven tuning for tiny CI checks.
|
| 126 |
+
model = os.environ.get("DATAFORGE_GROQ_MODEL", "llama-3.3-70b-versatile")
|
| 127 |
+
try:
|
| 128 |
+
min_interval_s = float(os.environ.get("DATAFORGE_GROQ_MIN_INTERVAL_S", "1.0"))
|
| 129 |
+
except ValueError:
|
| 130 |
+
min_interval_s = 1.0
|
| 131 |
+
try:
|
| 132 |
+
timeout_s = float(os.environ.get("DATAFORGE_GROQ_TIMEOUT_S", "30"))
|
| 133 |
+
except ValueError:
|
| 134 |
+
timeout_s = 30.0
|
| 135 |
+
try:
|
| 136 |
+
max_tokens = int(os.environ.get("DATAFORGE_GROQ_MAX_TOKENS", "256"))
|
| 137 |
+
except ValueError:
|
| 138 |
+
max_tokens = 256
|
| 139 |
+
try:
|
| 140 |
+
max_retries = int(os.environ.get("DATAFORGE_GROQ_MAX_RETRIES", "3"))
|
| 141 |
+
except ValueError:
|
| 142 |
+
max_retries = 3
|
| 143 |
+
client = GroqBenchClient(
|
| 144 |
+
api_key=os.environ["GROQ_API_KEY"],
|
| 145 |
+
model=model,
|
| 146 |
+
min_interval_s=min_interval_s,
|
| 147 |
+
max_tokens=max_tokens,
|
| 148 |
+
max_retries=max_retries,
|
| 149 |
+
timeout_s=timeout_s,
|
| 150 |
+
)
|
| 151 |
|
| 152 |
for dataset_name in datasets:
|
| 153 |
dataset = loaded_datasets[dataset_name]
|
| 154 |
for method in methods:
|
| 155 |
for seed in range(seeds):
|
| 156 |
+
if os.environ.get("DATAFORGE_BENCH_VERBOSE"):
|
| 157 |
+
print(
|
| 158 |
+
f"[dataforge bench] start method={method} dataset={dataset_name} seed={seed}",
|
| 159 |
+
file=sys.stderr,
|
| 160 |
+
flush=True,
|
| 161 |
+
)
|
| 162 |
if method == "random":
|
| 163 |
result = run_random_episode(dataset, seed=seed)
|
| 164 |
elif method == "heuristic":
|
|
|
|
| 192 |
if method == "heuristic":
|
| 193 |
result = result.model_copy(update={"seed": seed})
|
| 194 |
records.append(result)
|
| 195 |
+
if os.environ.get("DATAFORGE_BENCH_VERBOSE"):
|
| 196 |
+
print(
|
| 197 |
+
f"[dataforge bench] done method={method} dataset={dataset_name} seed={seed} status={result.status}",
|
| 198 |
+
file=sys.stderr,
|
| 199 |
+
flush=True,
|
| 200 |
+
)
|
| 201 |
|
| 202 |
aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
|
| 203 |
records, seeds_requested=seeds
|
dataforge/causal/__init__.py
CHANGED
|
@@ -1 +1,21 @@
|
|
| 1 |
-
"""Causal analysis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Causal analysis primitives for DataForge root-cause diagnosis."""
|
| 2 |
+
|
| 3 |
+
from dataforge.causal.dag import CausalDAG, CausalEdge
|
| 4 |
+
from dataforge.causal.pc import CausalDiscoveryResult, discover_causal_dag
|
| 5 |
+
from dataforge.causal.root_cause import (
|
| 6 |
+
CausalRootCauseAnalyzer,
|
| 7 |
+
ErrorEvidence,
|
| 8 |
+
RootCauseResult,
|
| 9 |
+
minimal_root_set,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"CausalDAG",
|
| 14 |
+
"CausalDiscoveryResult",
|
| 15 |
+
"CausalEdge",
|
| 16 |
+
"CausalRootCauseAnalyzer",
|
| 17 |
+
"ErrorEvidence",
|
| 18 |
+
"RootCauseResult",
|
| 19 |
+
"discover_causal_dag",
|
| 20 |
+
"minimal_root_set",
|
| 21 |
+
]
|
dataforge/causal/dag.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Column-level causal DAG utilities for root-cause analysis."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import networkx as nx # type: ignore[import-untyped]
|
| 9 |
+
|
| 10 |
+
__all__ = ["CausalDAG", "CausalEdge"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
class CausalEdge:
    """Immutable record describing one directed edge of the causal graph.

    Args:
        source: Name of the column the edge starts from.
        target: Name of the column the edge points to.
        confidence: Confidence in the directed influence, from 0.0 to 1.0.
        provenance: Human-readable description of where the edge came from.
    """

    source: str
    target: str
    confidence: float
    provenance: str
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CausalDAG:
    """Acyclic directed graph whose nodes are dataset columns.

    The graph is stored in a ``networkx.DiGraph``; each edge carries a
    ``confidence`` and a ``provenance`` attribute.

    Args:
        nodes: Optional initial column names.

    Example:
        >>> dag = CausalDAG(["discount_pct", "order_total"])
        >>> dag.add_edge("discount_pct", "order_total", confidence=0.9, provenance="fd")
        >>> dag.is_reachable("discount_pct", "order_total")
        True
    """

    def __init__(self, nodes: list[str] | tuple[str, ...] = ()) -> None:
        self._graph: nx.DiGraph[Any] = nx.DiGraph()
        self._graph.add_nodes_from(nodes)

    @property
    def nodes(self) -> tuple[str, ...]:
        """Return graph nodes in insertion order."""
        return tuple(str(node) for node in self._graph.nodes)

    @property
    def edges(self) -> tuple[CausalEdge, ...]:
        """Return directed edges with metadata.

        Missing attributes fall back to confidence ``0.0`` and provenance
        ``"unknown"``.
        """
        return tuple(
            CausalEdge(
                source=str(source),
                target=str(target),
                confidence=float(attrs.get("confidence", 0.0)),
                provenance=str(attrs.get("provenance", "unknown")),
            )
            for source, target, attrs in self._graph.edges(data=True)
        )

    def add_node(self, column: str) -> None:
        """Add a column node if it is not already present.

        Args:
            column: Column name.
        """
        self._graph.add_node(column)

    def add_edge(
        self,
        source: str,
        target: str,
        *,
        confidence: float,
        provenance: str,
    ) -> None:
        """Add a directed causal edge while preserving acyclicity.

        The graph is left untouched when the edge is rejected.

        Args:
            source: Source column name.
            target: Target column name.
            confidence: Confidence score; clamped into [0.0, 1.0], with NaN
                treated as 0.0.
            provenance: Source of the edge.

        Raises:
            ValueError: If the edge is self-referential or creates a cycle.
        """
        if source == target:
            raise ValueError("Causal DAG does not allow self-edges")
        # A cycle needs an existing target -> source path, which requires both
        # endpoints to be in the graph already. Checking before inserting the
        # nodes keeps a rejected edge from mutating the graph.
        if (
            source in self._graph
            and target in self._graph
            and nx.has_path(self._graph, target, source)
        ):
            raise ValueError(f"Adding {source!r} -> {target!r} would create a cycle")
        # ``confidence > 0.0`` is False for NaN as well as for non-positive
        # values, so NaN maps to 0.0 instead of slipping through min()/max()
        # as 1.0.
        bounded = min(1.0, confidence) if confidence > 0.0 else 0.0
        self._graph.add_node(source)
        self._graph.add_node(target)
        self._graph.add_edge(source, target, confidence=bounded, provenance=provenance)

    def successors(self, column: str) -> tuple[str, ...]:
        """Return direct downstream columns for a node.

        Args:
            column: Column name.

        Returns:
            A tuple of direct successor column names; empty when the column is
            not in the graph.
        """
        if column not in self._graph:
            return ()
        return tuple(str(node) for node in self._graph.successors(column))

    def is_reachable(self, source: str, target: str) -> bool:
        """Return whether target is reachable from source.

        Args:
            source: Source column name.
            target: Target column name.

        Returns:
            True if source equals target or a directed path exists.
        """
        if source == target:
            return True
        if source not in self._graph or target not in self._graph:
            return False
        return bool(nx.has_path(self._graph, source, target))

    def path_confidence(self, source: str, target: str) -> float:
        """Return the weakest-edge confidence on the shortest path.

        Args:
            source: Source column name.
            target: Target column name.

        Returns:
            Confidence in [0.0, 1.0], or 0.0 when no path exists.
        """
        if source == target:
            return 1.0
        if not self.is_reachable(source, target):
            return 0.0
        path = nx.shortest_path(self._graph, source, target)
        confidences = [
            float(self._graph.edges[head, tail].get("confidence", 0.0))
            for head, tail in zip(path, path[1:])
        ]
        return min(confidences, default=0.0)

    def minimal_root_columns(self, columns: list[str] | tuple[str, ...]) -> tuple[str, ...]:
        """Return selected columns that are not downstream of another selection.

        Args:
            columns: Selected error columns.

        Returns:
            Minimal root columns in first-seen order.
        """
        # dict.fromkeys removes duplicates while keeping first-seen order.
        unique = list(dict.fromkeys(columns))
        return tuple(
            column
            for column in unique
            if not any(
                other != column and self.is_reachable(other, column) for other in unique
            )
        )
|
dataforge/causal/pc.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PC-based causal DAG discovery with functional-dependency priors."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from scipy.stats import chi2_contingency # type: ignore[import-untyped]
|
| 11 |
+
|
| 12 |
+
from dataforge.causal.dag import CausalDAG
|
| 13 |
+
from dataforge.verifier.schema import Schema
|
| 14 |
+
|
| 15 |
+
__all__ = ["CausalDiscoveryResult", "discover_causal_dag"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True)
class CausalDiscoveryResult:
    """Immutable bundle returned by causal discovery.

    Args:
        dag: Directed acyclic graph over the dataset columns.
        confidence_report: Column-pair confidence or diagnostic metadata;
            defaults to a fresh empty dict per instance.
        warnings: Non-fatal discovery warnings; defaults to an empty tuple.
    """

    dag: CausalDAG
    confidence_report: dict[str, float] = field(default_factory=dict)
    warnings: tuple[str, ...] = ()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def discover_causal_dag(
    df: pd.DataFrame,
    schema: Schema | None = None,
    *,
    alpha: float = 0.05,
) -> CausalDiscoveryResult:
    """Infer a deterministic causal DAG from tabular data and FD priors.

    Edges are proposed from three sources, in this order: declared functional
    dependencies (confidence 0.95), causal-learn PC adjacencies (confidence
    0.55) and pairwise dependence tests (confidence taken from the test).
    The first accepted proposal for a ``"source->target"`` key fixes its
    reported confidence. A proposal that ``_try_add_edge`` turns into a
    warning is left out of the confidence report, so the report only lists
    edges the DAG accepted.

    Args:
        df: Input DataFrame.
        schema: Optional declared schema with functional dependencies.
        alpha: Significance threshold for independence checks.

    Returns:
        CausalDiscoveryResult. A DAG is returned even if PC orientation is
        underdetermined; low-confidence edges are tagged as such.
    """
    columns = [str(column) for column in df.columns]
    dag = CausalDAG(columns)
    report: dict[str, float] = {}
    warnings: list[str] = []

    def propose(source: str, target: str, confidence: float, provenance: str) -> None:
        """Try one edge and report it only when no rejection was recorded."""
        # _try_add_edge signals a rejected edge solely by appending a warning,
        # so compare the warning count before and after the attempt.
        rejected_before = len(warnings)
        _try_add_edge(
            dag,
            source,
            target,
            confidence=confidence,
            provenance=provenance,
            warnings=warnings,
        )
        if len(warnings) == rejected_before:
            report.setdefault(f"{source}->{target}", confidence)

    if schema is not None:
        for fd in schema.functional_dependencies:
            for determinant in fd.determinant:
                propose(determinant, fd.dependent, 0.95, "functional_dependency_prior")

    cleaned = _prepare_for_pc(df)
    pc_edges, pc_warning = _run_causal_learn_pc(cleaned.to_numpy(), columns, alpha)
    if pc_warning:
        warnings.append(pc_warning)
    for source, target in pc_edges:
        propose(source, target, 0.55, "causal_learn_pc")

    for source, target, confidence in _pairwise_dependency_edges(df, alpha):
        propose(source, target, confidence, "pairwise_ci_fallback")

    return CausalDiscoveryResult(dag=dag, confidence_report=report, warnings=tuple(warnings))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _prepare_for_pc(df: pd.DataFrame) -> pd.DataFrame:
|
| 98 |
+
"""Return numeric data with no NaN values for causal-learn PC."""
|
| 99 |
+
prepared = pd.DataFrame(index=df.index)
|
| 100 |
+
for column in df.columns:
|
| 101 |
+
numeric = pd.to_numeric(df[column], errors="coerce")
|
| 102 |
+
if numeric.notna().sum() >= max(2, int(0.5 * len(df))):
|
| 103 |
+
fill = float(numeric.median()) if numeric.notna().any() else 0.0
|
| 104 |
+
prepared[str(column)] = numeric.fillna(fill)
|
| 105 |
+
else:
|
| 106 |
+
codes, _ = pd.factorize(df[column].astype("string").fillna("<missing>"), sort=True)
|
| 107 |
+
prepared[str(column)] = codes.astype(float)
|
| 108 |
+
return prepared.fillna(0.0)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _run_causal_learn_pc(
|
| 112 |
+
data: np.ndarray[Any, Any], columns: list[str], alpha: float
|
| 113 |
+
) -> tuple[list[tuple[str, str]], str | None]:
|
| 114 |
+
"""Run causal-learn PC and return deterministic directed edges."""
|
| 115 |
+
try:
|
| 116 |
+
from causallearn.search.ConstraintBased.PC import pc # type: ignore[import-untyped]
|
| 117 |
+
|
| 118 |
+
result = pc(data, alpha=alpha, indep_test="fisherz", stable=True, show_progress=False)
|
| 119 |
+
except Exception as exc:
|
| 120 |
+
return [], f"causal-learn PC unavailable or failed: {exc}"
|
| 121 |
+
|
| 122 |
+
matrix = getattr(getattr(result, "G", None), "graph", None)
|
| 123 |
+
if matrix is None:
|
| 124 |
+
return [], "causal-learn PC returned no adjacency matrix"
|
| 125 |
+
|
| 126 |
+
edges: list[tuple[str, str]] = []
|
| 127 |
+
arr = np.asarray(matrix)
|
| 128 |
+
for i, source in enumerate(columns):
|
| 129 |
+
for j, target in enumerate(columns):
|
| 130 |
+
if i >= j or i >= arr.shape[0] or j >= arr.shape[1]:
|
| 131 |
+
continue
|
| 132 |
+
if arr[i, j] != 0 or arr[j, i] != 0:
|
| 133 |
+
edges.append((source, target))
|
| 134 |
+
return edges, None
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _pairwise_dependency_edges(df: pd.DataFrame, alpha: float) -> list[tuple[str, str, float]]:
    """Return deterministic low-confidence edges for dependent column pairs.

    Every unordered pair of columns is tested once, earlier column first.
    A pair whose p-value is below ``alpha`` yields an edge from the earlier
    column to the later one.

    Args:
        df: Input DataFrame.
        alpha: Significance threshold for the pairwise tests.

    Returns:
        ``(source, target, confidence)`` tuples, with confidence equal to
        ``1 - p`` clamped to ``[0.25, 0.75]`` and rounded to four decimals.
    """
    columns = [str(column) for column in df.columns]
    edges: list[tuple[str, str, float]] = []
    for i, source in enumerate(columns):
        for j in range(i + 1, len(columns)):
            # Select by position: the names above are stringified, so a label
            # lookup would raise KeyError for non-string labels and return a
            # DataFrame for duplicated labels.
            p_value = _pairwise_p_value(df.iloc[:, i], df.iloc[:, j])
            if p_value < alpha:
                confidence = max(0.25, min(0.75, 1.0 - p_value))
                edges.append((source, columns[j], round(confidence, 4)))
    return edges
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _pairwise_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
    """Choose and run the independence test that fits the two columns.

    A column counts as continuous when at least 80% of its values (and no
    fewer than five) parse as numbers. Two continuous columns go to the HSIC
    test with gaps filled by the median, two non-continuous columns go to the
    chi-squared test, and a mixed pair goes to the mutual-information test.
    """

    def parse(series: pd.Series[Any]) -> tuple[pd.Series[Any], bool]:
        parsed = pd.to_numeric(series, errors="coerce")
        is_continuous = parsed.notna().sum() >= max(5, int(0.8 * len(series)))
        return parsed, is_continuous

    left_numeric, left_cont = parse(left)
    right_numeric, right_cont = parse(right)

    if left_cont and right_cont:
        left_filled = left_numeric.fillna(left_numeric.median())
        right_filled = right_numeric.fillna(right_numeric.median())
        return _hsic_p_value(left_filled, right_filled)
    if left_cont or right_cont:
        return _mutual_information_p_value(left, right)
    return _chi_squared_p_value(left, right)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _chi_squared_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
|
| 167 |
+
"""Return chi-squared independence p-value for categorical pairs."""
|
| 168 |
+
table = pd.crosstab(
|
| 169 |
+
left.astype("string").fillna("<missing>"), right.astype("string").fillna("<missing>")
|
| 170 |
+
)
|
| 171 |
+
if table.shape[0] < 2 or table.shape[1] < 2:
|
| 172 |
+
return 1.0
|
| 173 |
+
_, p_value, _, _ = chi2_contingency(table)
|
| 174 |
+
return float(p_value)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _hsic_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
|
| 178 |
+
"""Return HSIC p-value for continuous pairs, with correlation fallback."""
|
| 179 |
+
x = left.to_numpy(dtype=float).reshape(-1, 1)
|
| 180 |
+
y = right.to_numpy(dtype=float).reshape(-1, 1)
|
| 181 |
+
try:
|
| 182 |
+
from hyppo.independence import Hsic # type: ignore[import-untyped]
|
| 183 |
+
|
| 184 |
+
_, p_value = Hsic().test(x, y, reps=100, auto=True)
|
| 185 |
+
return float(p_value)
|
| 186 |
+
except Exception:
|
| 187 |
+
corr = abs(float(np.corrcoef(x[:, 0], y[:, 0])[0, 1]))
|
| 188 |
+
return 0.0 if corr > 0.75 else 1.0
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _mutual_information_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
    """Return a p-value for dependence based on binned mutual information.

    Both columns are reduced to integer codes with ``_codes`` and
    cross-tabulated. The table is then scored with the log-likelihood-ratio
    (G) test, whose statistic equals ``2 * N * MI`` with MI the mutual
    information of the table in nats.

    Mutual information on its own is capped by the log of the number of
    bins regardless of sample size, so a p-value derived from MI alone
    cannot get small when one side has only a few bins. Scaling by the
    sample size and comparing against the chi-squared distribution gives a
    p-value that can be compared with a significance threshold.

    Args:
        left: First column; any mix of numeric and non-numeric values.
        right: Second column; any mix of numeric and non-numeric values.

    Returns:
        A p-value in ``[0, 1]``; 1.0 when the table is empty or either side
        has fewer than two codes.
    """
    table = pd.crosstab(_codes(left), _codes(right))
    total = float(table.to_numpy().sum())
    if total == 0.0 or table.shape[0] < 2 or table.shape[1] < 2:
        return 1.0
    # lambda_="log-likelihood" selects the G-test; the Yates correction is
    # disabled so the statistic stays exactly 2 * N * MI.
    _, p_value, _, _ = chi2_contingency(table, correction=False, lambda_="log-likelihood")
    return float(p_value)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _codes(series: pd.Series[Any]) -> np.ndarray[Any, Any]:
|
| 209 |
+
"""Return stable integer codes for a mixed-type series."""
|
| 210 |
+
numeric = pd.to_numeric(series, errors="coerce")
|
| 211 |
+
if numeric.notna().sum() >= max(5, int(0.8 * len(series))):
|
| 212 |
+
return pd.qcut(
|
| 213 |
+
numeric.fillna(numeric.median()), q=4, duplicates="drop"
|
| 214 |
+
).cat.codes.to_numpy()
|
| 215 |
+
codes, _ = pd.factorize(series.astype("string").fillna("<missing>"), sort=True)
|
| 216 |
+
return codes
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _try_add_edge(
|
| 220 |
+
dag: CausalDAG,
|
| 221 |
+
source: str,
|
| 222 |
+
target: str,
|
| 223 |
+
*,
|
| 224 |
+
confidence: float,
|
| 225 |
+
provenance: str,
|
| 226 |
+
warnings: list[str],
|
| 227 |
+
) -> None:
|
| 228 |
+
"""Add an edge or record the cycle warning."""
|
| 229 |
+
try:
|
| 230 |
+
dag.add_edge(source, target, confidence=confidence, provenance=provenance)
|
| 231 |
+
except ValueError as exc:
|
| 232 |
+
warnings.append(str(exc))
|
dataforge/causal/root_cause.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal root-cause selection over detected errors and a causal DAG."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Protocol
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
from dataforge.causal.dag import CausalDAG
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"CausalRootCauseAnalyzer",
|
| 13 |
+
"ErrorEvidence",
|
| 14 |
+
"RootCauseResult",
|
| 15 |
+
"evidence_from_issue",
|
| 16 |
+
"minimal_root_set",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class _IssueLike(Protocol):
|
| 21 |
+
"""Protocol for objects with row/column issue fields."""
|
| 22 |
+
|
| 23 |
+
row: int
|
| 24 |
+
column: str
|
| 25 |
+
issue_type: str
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ErrorEvidence(BaseModel):
    """Column-mapped detected error used for causal root-cause analysis.

    Args:
        index: Zero-based error index in the caller's selected issue list.
        row: Row index where the error was detected.
        column: Column where the error was detected.
        issue_type: Machine-readable issue type.
    """

    # Negative positions are rejected at validation time.
    index: int = Field(ge=0)
    row: int = Field(ge=0)
    # Empty strings are rejected at validation time.
    column: str = Field(min_length=1)
    issue_type: str = Field(min_length=1)

    # Instances are immutable once validated.
    model_config = {"frozen": True}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RootCauseResult(BaseModel):
    """Structured result returned by the root-cause analyzer.

    Args:
        root_indices: Minimal selected error indices.
        root_columns: Root columns corresponding to root_indices.
        covered_indices: Selected error indices covered by the root set.
        confidence: Mean path confidence from roots to covered errors.
        explanation: Human-readable explanation of the selected roots.
    """

    # root_indices and root_columns are built from the same roots, in the same order.
    root_indices: list[int]
    root_columns: list[str]
    covered_indices: list[int]
    # 0.0 when no error was covered; otherwise a mean rounded to four decimals.
    confidence: float
    explanation: str

    # Instances are immutable once validated.
    model_config = {"frozen": True}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class CausalRootCauseAnalyzer:
    """Select the errors that no other selected error can causally explain.

    Args:
        dag: Column-level causal DAG.

    Example:
        >>> dag = CausalDAG(["discount_pct", "order_total"])
        >>> dag.add_edge("discount_pct", "order_total", confidence=0.9, provenance="formula")
        >>> errors = [
        ...     ErrorEvidence(index=0, row=1, column="discount_pct", issue_type="bad"),
        ...     ErrorEvidence(index=1, row=1, column="order_total", issue_type="bad"),
        ... ]
        >>> CausalRootCauseAnalyzer(dag).analyze(errors).root_indices
        [0]
    """

    def __init__(self, dag: CausalDAG) -> None:
        self._dag = dag

    def analyze(self, errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...]) -> RootCauseResult:
        """Compute the minimal root set for the selected errors.

        Args:
            errors: Selected detected errors.

        Returns:
            RootCauseResult with roots, coverage, confidence, and explanation.
        """
        if not errors:
            return RootCauseResult(
                root_indices=[],
                root_columns=[],
                covered_indices=[],
                confidence=0.0,
                explanation="No errors were supplied.",
            )

        # A root is an error with no other selected error upstream of it.
        roots = [
            error for error in errors if not self._has_upstream_selected_error(error, errors)
        ]

        covered: list[int] = []
        path_confidences: list[float] = []
        for error in errors:
            # Credit each error to the first root that shares or reaches its column.
            explaining_root = next(
                (
                    root
                    for root in roots
                    if root.column == error.column
                    or self._dag.is_reachable(root.column, error.column)
                ),
                None,
            )
            if explaining_root is None:
                continue
            covered.append(error.index)
            path_confidences.append(
                self._dag.path_confidence(explaining_root.column, error.column)
            )

        if path_confidences:
            confidence = round(sum(path_confidences) / len(path_confidences), 4)
        else:
            confidence = 0.0
        root_columns = [root.column for root in roots]
        return RootCauseResult(
            root_indices=[root.index for root in roots],
            root_columns=root_columns,
            covered_indices=covered,
            confidence=confidence,
            explanation=self._explain(root_columns, len(covered), len(errors)),
        )

    def _has_upstream_selected_error(
        self,
        candidate: ErrorEvidence,
        errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...],
    ) -> bool:
        """Tell whether some other selected error causally precedes candidate.

        Within one column the error with the lower index wins; across columns
        the DAG's reachability decides.
        """
        for other in errors:
            if other.index == candidate.index:
                continue
            if other.column == candidate.column:
                if other.index < candidate.index:
                    return True
                continue
            if self._dag.is_reachable(other.column, candidate.column):
                return True
        return False

    @staticmethod
    def _explain(root_columns: list[str], covered_count: int, total_count: int) -> str:
        """Summarise the chosen roots and their coverage in one sentence."""
        if root_columns:
            names = ", ".join(root_columns)
            return f"Selected {names} as minimal roots covering {covered_count}/{total_count} errors."
        return "No minimal roots were found."
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def minimal_root_set(
    errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...], dag: CausalDAG
) -> RootCauseResult:
    """Run a one-off root-cause analysis without building the analyzer by hand.

    Args:
        errors: Selected detected errors.
        dag: Column-level causal DAG.

    Returns:
        The result of ``CausalRootCauseAnalyzer(dag).analyze(errors)``.
    """
    analyzer = CausalRootCauseAnalyzer(dag)
    return analyzer.analyze(errors)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def evidence_from_issue(index: int, issue: _IssueLike | dict[str, Any]) -> ErrorEvidence:
    """Convert an Issue-like object or dictionary into ErrorEvidence.

    Args:
        index: Error index to assign.
        issue: Object or dictionary with row/column/type fields. For a
            dictionary, ``"type"`` is preferred over ``"issue_type"``; a
            missing row defaults to 0, a missing column to ``""`` and a
            missing type to ``"unknown"``.

    Returns:
        ErrorEvidence instance.
    """
    if not isinstance(issue, dict):
        return ErrorEvidence(
            index=index,
            row=int(issue.row),
            column=str(issue.column),
            issue_type=str(issue.issue_type),
        )

    row = int(issue.get("row", 0))
    column = str(issue.get("column", ""))
    issue_type = str(issue.get("type", issue.get("issue_type", "unknown")))
    return ErrorEvidence(index=index, row=row, column=column, issue_type=issue_type)
|
dataforge/cli/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Typer application entrypoint for
|
| 2 |
|
| 3 |
Each CLI subcommand is defined in its own module under ``dataforge.cli.*``
|
| 4 |
and registered here. The ``app`` object is the entry point referenced by
|
|
@@ -7,13 +7,16 @@ and registered here. The ``app`` object is the entry point referenced by
|
|
| 7 |
|
| 8 |
import typer
|
| 9 |
|
|
|
|
| 10 |
from dataforge.cli.bench import bench
|
| 11 |
from dataforge.cli.profile import profile
|
| 12 |
from dataforge.cli.repair import repair
|
|
|
|
| 13 |
from dataforge.cli.revert import revert
|
|
|
|
| 14 |
|
| 15 |
app: typer.Typer = typer.Typer(
|
| 16 |
-
help="
|
| 17 |
no_args_is_help=True,
|
| 18 |
)
|
| 19 |
|
|
@@ -28,15 +31,18 @@ def _main(
|
|
| 28 |
is_eager=True,
|
| 29 |
),
|
| 30 |
) -> None:
|
| 31 |
-
"""
|
| 32 |
if version:
|
| 33 |
from dataforge import __version__
|
| 34 |
|
| 35 |
-
typer.echo(f"
|
| 36 |
raise typer.Exit()
|
| 37 |
|
| 38 |
|
| 39 |
app.command(name="profile")(profile)
|
| 40 |
app.command(name="repair")(repair)
|
| 41 |
app.command(name="revert")(revert)
|
|
|
|
| 42 |
app.command(name="bench")(bench)
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Typer application entrypoint for DataForge15.
|
| 2 |
|
| 3 |
Each CLI subcommand is defined in its own module under ``dataforge.cli.*``
|
| 4 |
and registered here. The ``app`` object is the entry point referenced by
|
|
|
|
| 7 |
|
| 8 |
import typer
|
| 9 |
|
| 10 |
+
from dataforge.cli.audit import audit
|
| 11 |
from dataforge.cli.bench import bench
|
| 12 |
from dataforge.cli.profile import profile
|
| 13 |
from dataforge.cli.repair import repair
|
| 14 |
+
from dataforge.cli.release import release_app
|
| 15 |
from dataforge.cli.revert import revert
|
| 16 |
+
from dataforge.cli.watch import watch
|
| 17 |
|
| 18 |
app: typer.Typer = typer.Typer(
|
| 19 |
+
help="DataForge15 - AI-powered data-quality detection and repair.",
|
| 20 |
no_args_is_help=True,
|
| 21 |
)
|
| 22 |
|
|
|
|
| 31 |
is_eager=True,
|
| 32 |
),
|
| 33 |
) -> None:
|
| 34 |
+
"""DataForge15 - AI-powered data-quality detection and repair."""
|
| 35 |
if version:
|
| 36 |
from dataforge import __version__
|
| 37 |
|
| 38 |
+
typer.echo(f"dataforge15 {__version__}")
|
| 39 |
raise typer.Exit()
|
| 40 |
|
| 41 |
|
| 42 |
app.command(name="profile")(profile)
|
| 43 |
app.command(name="repair")(repair)
|
| 44 |
app.command(name="revert")(revert)
|
| 45 |
+
app.command(name="audit")(audit)
|
| 46 |
app.command(name="bench")(bench)
|
| 47 |
+
app.command(name="watch")(watch)
|
| 48 |
+
app.add_typer(release_app, name="release")
|
dataforge/cli/audit.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI subcommand: ``dataforge audit <txn_id>``."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Annotated
|
| 8 |
+
|
| 9 |
+
import typer
|
| 10 |
+
from rich.console import Console
|
| 11 |
+
from rich.panel import Panel
|
| 12 |
+
|
| 13 |
+
from dataforge.transactions import TransactionAuditVerdict, verify_transaction_log
|
| 14 |
+
|
| 15 |
+
_console = Console(stderr=True)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def audit(
    txn_id: Annotated[
        str,
        typer.Argument(help="Transaction identifier to audit."),
    ],
    search_root: Annotated[
        Path | None,
        typer.Option(
            "--search-root",
            help="Root directory used to locate the transaction log.",
            exists=True,
            file_okay=False,
            dir_okay=True,
            readable=True,
        ),
    ] = None,
    log_path: Annotated[
        Path | None,
        typer.Option(
            "--log-path",
            help="Explicit JSONL transaction log path.",
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
        ),
    ] = None,
    json_output: Annotated[
        bool,
        typer.Option("--json", help="Print the audit report as JSON."),
    ] = False,
) -> None:
    """Verify a transaction log's local hash chain."""
    report = verify_transaction_log(txn_id, log_path=log_path, search_root=search_root)

    if json_output:
        # Machine-readable output goes through typer.echo.
        payload = report.model_dump(mode="json")
        typer.echo(json.dumps(payload, indent=2, sort_keys=True))
    else:
        # Human-readable panel goes through the module's rich console.
        if report.verdict == TransactionAuditVerdict.VERIFIED:
            style = "green"
        else:
            style = "red"
        summary = [
            f"Verdict: [bold]{report.verdict.value}[/bold]",
            f"Transaction: {report.txn_id or txn_id}",
            f"Events: {report.event_count}",
            f"Head SHA-256: {report.head_sha256 or 'n/a'}",
        ]
        body = "\n".join(summary)
        if report.errors:
            body += "\n\n" + "\n".join(f"- {error}" for error in report.errors)
        _console.print(Panel(body, title="Transaction Audit", style=style))

    # Exit status: 0 verified, 1 legacy/unverified, 2 for any other verdict.
    if report.verdict == TransactionAuditVerdict.VERIFIED:
        exit_code = 0
    elif report.verdict == TransactionAuditVerdict.LEGACY_UNVERIFIED:
        exit_code = 1
    else:
        exit_code = 2
    raise typer.Exit(code=exit_code)
|
dataforge/cli/bench.py
CHANGED
|
@@ -2,17 +2,18 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Annotated
|
| 7 |
|
| 8 |
import typer
|
| 9 |
from rich.console import Console
|
| 10 |
from rich.panel import Panel
|
| 11 |
from rich.table import Table
|
| 12 |
|
| 13 |
-
from dataforge.bench.runner import run_agent_comparison
|
| 14 |
-
|
| 15 |
_console = Console(stderr=True)
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def _parse_csv_list(raw_value: str) -> list[str]:
|
|
@@ -21,6 +22,16 @@ def _parse_csv_list(raw_value: str) -> list[str]:
|
|
| 21 |
return [value for value in values if value]
|
| 22 |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def bench(
|
| 25 |
methods: Annotated[
|
| 26 |
str,
|
|
@@ -54,10 +65,14 @@ def bench(
|
|
| 54 |
help="Where to write eval/results/agent_comparison.json.",
|
| 55 |
),
|
| 56 |
] = Path("eval/results/agent_comparison.json"),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
) -> None:
|
| 58 |
"""Run real-world benchmark methods across cached benchmark datasets."""
|
| 59 |
try:
|
| 60 |
-
output =
|
| 61 |
methods=_parse_csv_list(methods),
|
| 62 |
datasets=_parse_csv_list(datasets),
|
| 63 |
seeds=seeds,
|
|
@@ -74,6 +89,10 @@ def bench(
|
|
| 74 |
)
|
| 75 |
raise typer.Exit(code=2) from exc
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
table = Table(title="DataForge Benchmark Summary")
|
| 78 |
table.add_column("Method")
|
| 79 |
table.add_column("Dataset")
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import json
|
| 6 |
+
from collections.abc import Callable
|
| 7 |
from pathlib import Path
|
| 8 |
+
from typing import Annotated, Any
|
| 9 |
|
| 10 |
import typer
|
| 11 |
from rich.console import Console
|
| 12 |
from rich.panel import Panel
|
| 13 |
from rich.table import Table
|
| 14 |
|
|
|
|
|
|
|
| 15 |
_console = Console(stderr=True)
|
| 16 |
+
run_agent_comparison: Callable[..., Any] | None = None
|
| 17 |
|
| 18 |
|
| 19 |
def _parse_csv_list(raw_value: str) -> list[str]:
|
|
|
|
| 22 |
return [value for value in values if value]
|
| 23 |
|
| 24 |
|
| 25 |
+
def _runner() -> Callable[..., Any]:
    """Load the benchmark runner lazily so core CLI imports stay lightweight.

    The loaded callable is stored in the module-level ``run_agent_comparison``
    name, so the import happens only while that name is still ``None``.
    """
    global run_agent_comparison
    if run_agent_comparison is not None:
        return run_agent_comparison

    from dataforge.bench.runner import run_agent_comparison as loaded_runner

    run_agent_comparison = loaded_runner
    return loaded_runner
|
| 33 |
+
|
| 34 |
+
|
| 35 |
def bench(
|
| 36 |
methods: Annotated[
|
| 37 |
str,
|
|
|
|
| 65 |
help="Where to write eval/results/agent_comparison.json.",
|
| 66 |
),
|
| 67 |
] = Path("eval/results/agent_comparison.json"),
|
| 68 |
+
json_output: Annotated[
|
| 69 |
+
bool,
|
| 70 |
+
typer.Option("--json", help="Print benchmark results as JSON."),
|
| 71 |
+
] = False,
|
| 72 |
) -> None:
|
| 73 |
"""Run real-world benchmark methods across cached benchmark datasets."""
|
| 74 |
try:
|
| 75 |
+
output = _runner()(
|
| 76 |
methods=_parse_csv_list(methods),
|
| 77 |
datasets=_parse_csv_list(datasets),
|
| 78 |
seeds=seeds,
|
|
|
|
| 89 |
)
|
| 90 |
raise typer.Exit(code=2) from exc
|
| 91 |
|
| 92 |
+
if json_output:
|
| 93 |
+
typer.echo(json.dumps(output.model_dump(mode="json"), indent=2, sort_keys=True))
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
table = Table(title="DataForge Benchmark Summary")
|
| 97 |
table.add_column("Method")
|
| 98 |
table.add_column("Dataset")
|
dataforge/cli/common.py
CHANGED
|
@@ -3,13 +3,14 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
from collections.abc import Iterable
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import cast
|
| 8 |
|
| 9 |
-
import pandas as pd
|
| 10 |
import typer
|
| 11 |
import yaml
|
| 12 |
|
|
|
|
| 13 |
from dataforge.verifier.schema import (
|
| 14 |
AggregateDependency,
|
| 15 |
AggregateLiteral,
|
|
@@ -18,6 +19,27 @@ from dataforge.verifier.schema import (
|
|
| 18 |
Schema,
|
| 19 |
)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def schema_from_mapping(raw_mapping: object) -> Schema:
|
| 23 |
"""Build a Schema from a raw YAML mapping-like payload.
|
|
@@ -149,13 +171,13 @@ def load_schema(schema_path: Path) -> Schema:
|
|
| 149 |
return schema_from_mapping(raw)
|
| 150 |
|
| 151 |
|
| 152 |
-
def read_csv(path: Path) ->
|
| 153 |
"""Read a CSV using conservative string-preserving defaults.
|
| 154 |
|
| 155 |
Args:
|
| 156 |
path: CSV path.
|
| 157 |
|
| 158 |
Returns:
|
| 159 |
-
A
|
| 160 |
"""
|
| 161 |
-
return
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
from collections.abc import Iterable
|
| 6 |
+
from importlib import resources
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import cast
|
| 9 |
|
|
|
|
| 10 |
import typer
|
| 11 |
import yaml
|
| 12 |
|
| 13 |
+
from dataforge.table import Table, read_csv as read_table_csv
|
| 14 |
from dataforge.verifier.schema import (
|
| 15 |
AggregateDependency,
|
| 16 |
AggregateLiteral,
|
|
|
|
| 19 |
Schema,
|
| 20 |
)
|
| 21 |
|
| 22 |
+
# Maps a user-typed relative path to the resource path inside the
# ``dataforge`` package that holds the same demo file.
_PACKAGED_DEMO_FIXTURES = {
    "fixtures/hospital_10rows.csv": "fixtures/hospital_10rows.csv",
    "fixtures/hospital_schema.yaml": "fixtures/hospital_schema.yaml",
}


def resolve_cli_path(path: Path) -> Path:
    """Resolve a user path, including DataForge's packaged demo fixture aliases.

    An existing path is returned untouched. A missing path that spells one of
    the demo fixture aliases (optionally prefixed with ``./``) is redirected
    to the copy shipped inside the ``dataforge`` package. Anything else is
    returned unchanged so the caller reports the missing file itself.

    Args:
        path: Path as supplied on the command line.

    Returns:
        The original path, or the packaged fixture path for a known alias.
    """
    if path.exists():
        return path

    normalized = path.as_posix().replace("\\", "/")
    # Remove whole "./" prefixes only. str.lstrip("./") strips any run of "."
    # and "/" characters, which would also turn "../fixtures/..." or
    # "/fixtures/..." into an alias match.
    while normalized.startswith("./"):
        normalized = normalized[2:]

    packaged_name = _PACKAGED_DEMO_FIXTURES.get(normalized)
    if packaged_name is None:
        return path

    fixture = resources.files("dataforge").joinpath(packaged_name)
    if not fixture.is_file():
        return path
    # NOTE(review): str(fixture) is assumed to be a real filesystem path,
    # i.e. the package is installed unpacked - confirm for zipped installs.
    return Path(str(fixture))
|
| 42 |
+
|
| 43 |
|
| 44 |
def schema_from_mapping(raw_mapping: object) -> Schema:
|
| 45 |
"""Build a Schema from a raw YAML mapping-like payload.
|
|
|
|
| 171 |
return schema_from_mapping(raw)
|
| 172 |
|
| 173 |
|
| 174 |
+
def read_csv(path: Path) -> Table:
    """Load a CSV file with conservative, string-preserving defaults.

    Args:
        path: CSV path.

    Returns:
        The string-preserving DataForge table built by ``dataforge.table.read_csv``.
    """
    table = read_table_csv(path)
    return table
|
dataforge/cli/profile.py
CHANGED
|
@@ -1,31 +1,46 @@
|
|
| 1 |
"""CLI subcommand: ``dataforge profile <path> [--schema <yaml>]``.
|
| 2 |
|
| 3 |
Reads a CSV file, runs all detectors, and renders detected issues as a
|
| 4 |
-
rich-formatted terminal table.
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
|
|
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
-
from typing import Annotated
|
| 11 |
|
| 12 |
import typer
|
| 13 |
from rich.console import Console
|
| 14 |
|
| 15 |
-
from dataforge.cli.common import load_schema, read_csv
|
| 16 |
from dataforge.detectors import run_all_detectors
|
| 17 |
-
from dataforge.detectors.base import Schema, Severity
|
| 18 |
from dataforge.ui.profile_view import render_profile_table
|
| 19 |
|
| 20 |
_console = Console(stderr=True)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def profile(
|
| 24 |
path: Annotated[
|
| 25 |
Path,
|
| 26 |
typer.Argument(
|
| 27 |
-
exists=True,
|
| 28 |
-
readable=True,
|
| 29 |
help="Path to the CSV file to profile.",
|
| 30 |
),
|
| 31 |
],
|
|
@@ -33,22 +48,36 @@ def profile(
|
|
| 33 |
Path | None,
|
| 34 |
typer.Option(
|
| 35 |
"--schema",
|
| 36 |
-
exists=True,
|
| 37 |
-
readable=True,
|
| 38 |
help="Path to a YAML schema file with column types and FDs.",
|
| 39 |
),
|
| 40 |
] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
) -> None:
|
| 42 |
"""Profile a CSV file for data-quality issues.
|
| 43 |
|
| 44 |
Reads the CSV, runs all detectors (type_mismatch, decimal_shift,
|
| 45 |
fd_violation), and renders a rich-formatted table of detected issues.
|
| 46 |
|
| 47 |
-
Exit code 0
|
| 48 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# Load the CSV with dtype=str to avoid pandas type-coercion artifacts.
|
| 50 |
try:
|
| 51 |
-
df = read_csv(
|
| 52 |
except Exception as exc:
|
| 53 |
_console.print(f"[bold red]Error reading CSV:[/bold red] {exc}")
|
| 54 |
raise typer.Exit(code=2) from exc
|
|
@@ -56,16 +85,32 @@ def profile(
|
|
| 56 |
# Optionally load schema.
|
| 57 |
parsed_schema: Schema | None = None
|
| 58 |
if schema is not None:
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# Run all detectors.
|
| 62 |
issues = run_all_detectors(df, parsed_schema)
|
| 63 |
|
| 64 |
# Render the results.
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
|
| 69 |
-
has_unsafe = any(i.severity == Severity.UNSAFE for i in issues)
|
| 70 |
-
if has_unsafe:
|
| 71 |
raise typer.Exit(code=1)
|
|
|
|
| 1 |
"""CLI subcommand: ``dataforge profile <path> [--schema <yaml>]``.
|
| 2 |
|
| 3 |
Reads a CSV file, runs all detectors, and renders detected issues as a
|
| 4 |
+
rich-formatted terminal table. Diagnostics exit 0 by default; use
|
| 5 |
+
``--fail-on`` for CI gating.
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
+
import json
|
| 11 |
+
from collections.abc import Sequence
|
| 12 |
from pathlib import Path
|
| 13 |
+
from typing import Annotated, Literal
|
| 14 |
|
| 15 |
import typer
|
| 16 |
from rich.console import Console
|
| 17 |
|
| 18 |
+
from dataforge.cli.common import load_schema, read_csv, resolve_cli_path
|
| 19 |
from dataforge.detectors import run_all_detectors
|
| 20 |
+
from dataforge.detectors.base import Issue, Schema, Severity
|
| 21 |
from dataforge.ui.profile_view import render_profile_table
|
| 22 |
|
| 23 |
_console = Console(stderr=True)
|
| 24 |
|
| 25 |
+
FailOn = Literal["never", "unsafe", "review", "any"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _should_fail(issues: Sequence[Issue], fail_on: FailOn) -> bool:
|
| 29 |
+
"""Return whether profile findings should trip the requested CI gate."""
|
| 30 |
+
if fail_on == "never":
|
| 31 |
+
return False
|
| 32 |
+
if fail_on == "any":
|
| 33 |
+
return bool(issues)
|
| 34 |
+
severities = [issue.severity for issue in issues]
|
| 35 |
+
if fail_on == "unsafe":
|
| 36 |
+
return any(severity == Severity.UNSAFE for severity in severities)
|
| 37 |
+
return any(severity >= Severity.REVIEW for severity in severities)
|
| 38 |
+
|
| 39 |
|
| 40 |
def profile(
|
| 41 |
path: Annotated[
|
| 42 |
Path,
|
| 43 |
typer.Argument(
|
|
|
|
|
|
|
| 44 |
help="Path to the CSV file to profile.",
|
| 45 |
),
|
| 46 |
],
|
|
|
|
| 48 |
Path | None,
|
| 49 |
typer.Option(
|
| 50 |
"--schema",
|
|
|
|
|
|
|
| 51 |
help="Path to a YAML schema file with column types and FDs.",
|
| 52 |
),
|
| 53 |
] = None,
|
| 54 |
+
json_output: Annotated[
|
| 55 |
+
bool,
|
| 56 |
+
typer.Option("--json", help="Print profile results as JSON."),
|
| 57 |
+
] = False,
|
| 58 |
+
fail_on: Annotated[
|
| 59 |
+
FailOn,
|
| 60 |
+
typer.Option(
|
| 61 |
+
"--fail-on",
|
| 62 |
+
help="Exit 1 when findings meet this threshold: never, unsafe, review, any.",
|
| 63 |
+
),
|
| 64 |
+
] = "never",
|
| 65 |
) -> None:
|
| 66 |
"""Profile a CSV file for data-quality issues.
|
| 67 |
|
| 68 |
Reads the CSV, runs all detectors (type_mismatch, decimal_shift,
|
| 69 |
fd_violation), and renders a rich-formatted table of detected issues.
|
| 70 |
|
| 71 |
+
Exit code 0 unless ``--fail-on`` is set and matching findings are present.
|
| 72 |
"""
|
| 73 |
+
resolved_path = resolve_cli_path(path)
|
| 74 |
+
if not resolved_path.exists():
|
| 75 |
+
_console.print(f"[bold red]CSV file not found:[/bold red] {path}")
|
| 76 |
+
raise typer.Exit(code=2)
|
| 77 |
+
|
| 78 |
# Load the CSV with dtype=str to avoid pandas type-coercion artifacts.
|
| 79 |
try:
|
| 80 |
+
df = read_csv(resolved_path)
|
| 81 |
except Exception as exc:
|
| 82 |
_console.print(f"[bold red]Error reading CSV:[/bold red] {exc}")
|
| 83 |
raise typer.Exit(code=2) from exc
|
|
|
|
| 85 |
# Optionally load schema.
|
| 86 |
parsed_schema: Schema | None = None
|
| 87 |
if schema is not None:
|
| 88 |
+
resolved_schema = resolve_cli_path(schema)
|
| 89 |
+
if not resolved_schema.exists():
|
| 90 |
+
_console.print(f"[bold red]Schema file not found:[/bold red] {schema}")
|
| 91 |
+
raise typer.Exit(code=2)
|
| 92 |
+
parsed_schema = load_schema(resolved_schema)
|
| 93 |
|
| 94 |
# Run all detectors.
|
| 95 |
issues = run_all_detectors(df, parsed_schema)
|
| 96 |
|
| 97 |
# Render the results.
|
| 98 |
+
if json_output:
|
| 99 |
+
typer.echo(
|
| 100 |
+
json.dumps(
|
| 101 |
+
{
|
| 102 |
+
"path": str(resolved_path),
|
| 103 |
+
"issues_count": len(issues),
|
| 104 |
+
"fail_on": fail_on,
|
| 105 |
+
"issues": [issue.model_dump(mode="json") for issue in issues],
|
| 106 |
+
},
|
| 107 |
+
indent=2,
|
| 108 |
+
sort_keys=True,
|
| 109 |
+
)
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
output_console = Console()
|
| 113 |
+
render_profile_table(issues, output_console, file_path=str(resolved_path))
|
| 114 |
|
| 115 |
+
if _should_fail(issues, fail_on):
|
|
|
|
|
|
|
| 116 |
raise typer.Exit(code=1)
|
dataforge/cli/release.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI group for local release verification."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Annotated
|
| 8 |
+
|
| 9 |
+
import typer
|
| 10 |
+
|
| 11 |
+
from dataforge.release.doctor import DEFAULT_KAGGLE_CREDENTIALS, run_doctor
|
| 12 |
+
|
| 13 |
+
release_app = typer.Typer(help="Release verification utilities.", no_args_is_help=True)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@release_app.command(name="doctor")
|
| 17 |
+
def doctor(
|
| 18 |
+
json_output: Annotated[
|
| 19 |
+
bool,
|
| 20 |
+
typer.Option("--json", help="Print machine-readable JSON."),
|
| 21 |
+
] = False,
|
| 22 |
+
kaggle_credentials: Annotated[
|
| 23 |
+
Path,
|
| 24 |
+
typer.Option(
|
| 25 |
+
"--kaggle-credentials",
|
| 26 |
+
help="Path to Kaggle OAuth credentials.json. Legacy kaggle.json is never read.",
|
| 27 |
+
),
|
| 28 |
+
] = DEFAULT_KAGGLE_CREDENTIALS,
|
| 29 |
+
) -> None:
|
| 30 |
+
"""Verify local release/deploy auth without printing secrets."""
|
| 31 |
+
report = run_doctor(kaggle_credentials=kaggle_credentials)
|
| 32 |
+
if json_output:
|
| 33 |
+
typer.echo(json.dumps(report.to_dict(), indent=2, sort_keys=True))
|
| 34 |
+
else:
|
| 35 |
+
for check in report.checks:
|
| 36 |
+
status = "ok" if check.ok else "fail"
|
| 37 |
+
typer.echo(f"{status:4} {check.name}: {check.detail}")
|
| 38 |
+
raise typer.Exit(code=0 if report.ok else 2)
|
| 39 |
+
|
dataforge/cli/repair.py
CHANGED
|
@@ -2,32 +2,25 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
-
import
|
| 6 |
-
from datetime import UTC, datetime
|
| 7 |
from pathlib import Path
|
| 8 |
-
from typing import Annotated
|
| 9 |
|
| 10 |
-
import pandas as pd
|
| 11 |
import typer
|
| 12 |
from rich.console import Console
|
| 13 |
from rich.panel import Panel
|
| 14 |
|
| 15 |
-
from dataforge.cli.common import load_schema,
|
| 16 |
-
from dataforge.detectors import run_all_detectors
|
| 17 |
from dataforge.detectors.base import Issue, Schema
|
| 18 |
-
from dataforge.repairers import
|
| 19 |
-
from dataforge.
|
| 20 |
-
from dataforge.
|
| 21 |
-
from dataforge.transactions.log import (
|
| 22 |
-
append_applied_event,
|
| 23 |
-
append_created_transaction,
|
| 24 |
-
cache_dir_for,
|
| 25 |
-
sha256_bytes,
|
| 26 |
-
snapshot_path_for,
|
| 27 |
-
)
|
| 28 |
-
from dataforge.transactions.txn import CellFix, RepairTransaction, generate_txn_id
|
| 29 |
from dataforge.ui.repair_diff import render_repair_diff
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
_console = Console(stderr=True)
|
| 33 |
|
|
@@ -45,32 +38,19 @@ def apply_fixes_to_csv(path: Path, fixes: list[CellFix]) -> str:
|
|
| 45 |
Raises:
|
| 46 |
ValueError: If a fix references a missing row/column or stale old value.
|
| 47 |
"""
|
| 48 |
-
|
| 49 |
-
for fix in fixes:
|
| 50 |
-
if fix.operation != "update":
|
| 51 |
-
raise ValueError(f"Unsupported repair operation '{fix.operation}' for row {fix.row}.")
|
| 52 |
-
if fix.column not in df.columns:
|
| 53 |
-
raise ValueError(f"Column '{fix.column}' not found in '{path}'.")
|
| 54 |
-
if fix.row < 0 or fix.row >= len(df.index):
|
| 55 |
-
raise ValueError(f"Row {fix.row} is out of bounds for '{path}'.")
|
| 56 |
-
|
| 57 |
-
current_value = str(df.at[fix.row, fix.column])
|
| 58 |
-
if current_value != fix.old_value:
|
| 59 |
-
raise ValueError(
|
| 60 |
-
f"Refusing to apply stale fix for row {fix.row}, column '{fix.column}': "
|
| 61 |
-
f"expected '{fix.old_value}', found '{current_value}'."
|
| 62 |
-
)
|
| 63 |
-
df.at[fix.row, fix.column] = fix.new_value
|
| 64 |
|
| 65 |
-
|
| 66 |
-
return hashlib.sha256(path.read_bytes()).hexdigest()
|
| 67 |
|
| 68 |
|
| 69 |
def _resolve_schema(schema_path: Path | None) -> Schema | None:
|
| 70 |
"""Resolve an optional schema path into a parsed Schema."""
|
| 71 |
if schema_path is None:
|
| 72 |
return None
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def _print_error(message: str, *, hint: str | None = None) -> None:
|
|
@@ -94,157 +74,21 @@ def _propose_repairs(
|
|
| 94 |
confirm_escalations: bool,
|
| 95 |
interactive: bool,
|
| 96 |
) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]:
|
| 97 |
-
"""
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
allow_llm=allow_llm,
|
| 101 |
model=model,
|
| 102 |
-
)
|
| 103 |
-
safety_filter = SafetyFilter()
|
| 104 |
-
verifier = SMTVerifier()
|
| 105 |
-
safety_context = SafetyContext(
|
| 106 |
allow_pii=allow_pii,
|
| 107 |
confirm_pii=confirm_pii,
|
| 108 |
confirm_escalations=confirm_escalations,
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
accepted_fixes: list[ProposedFix] = []
|
| 112 |
-
attempt_groups: list[list[RepairAttempt]] = []
|
| 113 |
-
|
| 114 |
-
for issue in issues:
|
| 115 |
-
attempts: list[RepairAttempt] = []
|
| 116 |
-
repairer = repairers.get(issue.issue_type)
|
| 117 |
-
if repairer is None:
|
| 118 |
-
attempts.append(
|
| 119 |
-
RepairAttempt(
|
| 120 |
-
issue=issue,
|
| 121 |
-
attempt_number=1,
|
| 122 |
-
status="attempted_not_fixed",
|
| 123 |
-
reason="No repairer is registered for this issue type.",
|
| 124 |
-
)
|
| 125 |
-
)
|
| 126 |
-
attempt_groups.append(attempts)
|
| 127 |
-
continue
|
| 128 |
-
|
| 129 |
-
accepted = False
|
| 130 |
-
retry_context = RetryContext(issue=issue)
|
| 131 |
-
for attempt_number in range(1, 4):
|
| 132 |
-
candidate = repairer.propose(issue, working_df, schema, retry_context=retry_context)
|
| 133 |
-
if candidate is None:
|
| 134 |
-
attempts.append(
|
| 135 |
-
RepairAttempt(
|
| 136 |
-
issue=issue,
|
| 137 |
-
attempt_number=attempt_number,
|
| 138 |
-
status="attempted_not_fixed",
|
| 139 |
-
reason="No repair proposal was available for this issue.",
|
| 140 |
-
)
|
| 141 |
-
)
|
| 142 |
-
break
|
| 143 |
-
|
| 144 |
-
preferred = safety_filter.choose_preferred([candidate], schema, safety_context)
|
| 145 |
-
safety_result = safety_filter.evaluate(preferred, schema, safety_context)
|
| 146 |
-
if safety_result.verdict == SafetyVerdict.ESCALATE and interactive:
|
| 147 |
-
safety_context, safety_result = _resolve_escalation(
|
| 148 |
-
preferred,
|
| 149 |
-
schema,
|
| 150 |
-
safety_context,
|
| 151 |
-
safety_filter,
|
| 152 |
-
safety_result,
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
if safety_result.verdict == SafetyVerdict.DENY:
|
| 156 |
-
attempts.append(
|
| 157 |
-
RepairAttempt(
|
| 158 |
-
issue=issue,
|
| 159 |
-
attempt_number=attempt_number,
|
| 160 |
-
fix=preferred,
|
| 161 |
-
status="denied",
|
| 162 |
-
reason=safety_result.reason,
|
| 163 |
-
)
|
| 164 |
-
)
|
| 165 |
-
retry_context = _build_retry_context(issue, attempts)
|
| 166 |
-
continue
|
| 167 |
-
|
| 168 |
-
if safety_result.verdict == SafetyVerdict.ESCALATE:
|
| 169 |
-
attempts.append(
|
| 170 |
-
RepairAttempt(
|
| 171 |
-
issue=issue,
|
| 172 |
-
attempt_number=attempt_number,
|
| 173 |
-
fix=preferred,
|
| 174 |
-
status="escalated",
|
| 175 |
-
reason=safety_result.reason,
|
| 176 |
-
)
|
| 177 |
-
)
|
| 178 |
-
break
|
| 179 |
-
|
| 180 |
-
verifier_result = verifier.verify(working_df, [preferred], schema)
|
| 181 |
-
if verifier_result.verdict == VerificationVerdict.ACCEPT:
|
| 182 |
-
accepted_fixes.append(preferred)
|
| 183 |
-
working_df.at[preferred.fix.row, preferred.fix.column] = preferred.fix.new_value
|
| 184 |
-
attempts.append(
|
| 185 |
-
RepairAttempt(
|
| 186 |
-
issue=issue,
|
| 187 |
-
attempt_number=attempt_number,
|
| 188 |
-
fix=preferred,
|
| 189 |
-
status="accepted",
|
| 190 |
-
reason=verifier_result.reason,
|
| 191 |
-
)
|
| 192 |
-
)
|
| 193 |
-
accepted = True
|
| 194 |
-
break
|
| 195 |
-
|
| 196 |
-
attempts.append(
|
| 197 |
-
RepairAttempt(
|
| 198 |
-
issue=issue,
|
| 199 |
-
attempt_number=attempt_number,
|
| 200 |
-
fix=preferred,
|
| 201 |
-
status=(
|
| 202 |
-
"rejected"
|
| 203 |
-
if verifier_result.verdict == VerificationVerdict.REJECT
|
| 204 |
-
else "unknown"
|
| 205 |
-
),
|
| 206 |
-
reason=verifier_result.reason,
|
| 207 |
-
unsat_core=verifier_result.unsat_core,
|
| 208 |
-
)
|
| 209 |
-
)
|
| 210 |
-
retry_context = _build_retry_context(issue, attempts)
|
| 211 |
-
|
| 212 |
-
if (
|
| 213 |
-
not accepted
|
| 214 |
-
and attempts
|
| 215 |
-
and attempts[-1].status not in {"attempted_not_fixed", "escalated"}
|
| 216 |
-
):
|
| 217 |
-
last_reason = attempts[-1].reason
|
| 218 |
-
attempts[-1] = attempts[-1].model_copy(
|
| 219 |
-
update={
|
| 220 |
-
"status": "attempted_not_fixed",
|
| 221 |
-
"reason": (
|
| 222 |
-
f"Issue was attempted but not fixed after {len(attempts)} attempt(s). "
|
| 223 |
-
f"Last failure: {last_reason}"
|
| 224 |
-
),
|
| 225 |
-
}
|
| 226 |
-
)
|
| 227 |
-
attempt_groups.append(attempts)
|
| 228 |
-
|
| 229 |
-
return accepted_fixes, attempt_groups
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
def _build_retry_context(issue: Issue, attempts: list[RepairAttempt]) -> RetryContext:
|
| 233 |
-
"""Build retry hints from previous failed attempts."""
|
| 234 |
-
rejected_values = frozenset(
|
| 235 |
-
attempt.fix.fix.new_value
|
| 236 |
-
for attempt in attempts
|
| 237 |
-
if attempt.fix is not None and attempt.status in {"denied", "rejected", "unknown"}
|
| 238 |
-
)
|
| 239 |
-
hints: list[str] = []
|
| 240 |
-
for attempt in attempts:
|
| 241 |
-
hints.append(attempt.reason)
|
| 242 |
-
hints.extend(attempt.unsat_core)
|
| 243 |
-
return RetryContext(
|
| 244 |
-
issue=issue,
|
| 245 |
-
previous_attempts=tuple(attempts),
|
| 246 |
-
rejected_values=rejected_values,
|
| 247 |
-
hints=tuple(hints),
|
| 248 |
)
|
| 249 |
|
| 250 |
|
|
@@ -309,45 +153,46 @@ def _render_attempt_summary(
|
|
| 309 |
return len(failed_groups)
|
| 310 |
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
def _apply_transaction(
|
| 313 |
path: Path,
|
| 314 |
fixes: list[ProposedFix],
|
| 315 |
source_bytes: bytes,
|
| 316 |
) -> str:
|
| 317 |
-
"""
|
| 318 |
-
|
| 319 |
-
txn_id = generate_txn_id()
|
| 320 |
-
snapshot_path = snapshot_path_for(resolved_path, txn_id)
|
| 321 |
-
snapshot_path.parent.mkdir(parents=True, exist_ok=True)
|
| 322 |
-
snapshot_path.write_bytes(source_bytes)
|
| 323 |
-
|
| 324 |
-
transaction = RepairTransaction(
|
| 325 |
-
txn_id=txn_id,
|
| 326 |
-
created_at=datetime.now(UTC),
|
| 327 |
-
source_path=str(resolved_path),
|
| 328 |
-
source_sha256=sha256_bytes(source_bytes),
|
| 329 |
-
source_snapshot_path=str(snapshot_path.resolve()),
|
| 330 |
-
fixes=[proposal.fix for proposal in fixes],
|
| 331 |
-
applied=False,
|
| 332 |
-
)
|
| 333 |
-
log_path = append_created_transaction(transaction)
|
| 334 |
|
| 335 |
-
|
| 336 |
-
post_sha256 = apply_fixes_to_csv(path, [proposal.fix for proposal in fixes])
|
| 337 |
-
append_applied_event(log_path, txn_id, post_sha256=post_sha256)
|
| 338 |
-
except Exception:
|
| 339 |
-
path.write_bytes(source_bytes)
|
| 340 |
-
raise
|
| 341 |
-
|
| 342 |
-
return txn_id
|
| 343 |
|
| 344 |
|
| 345 |
def repair(
|
| 346 |
path: Annotated[
|
| 347 |
Path,
|
| 348 |
typer.Argument(
|
| 349 |
-
exists=True,
|
| 350 |
-
readable=True,
|
| 351 |
help="Path to the CSV file to repair.",
|
| 352 |
),
|
| 353 |
],
|
|
@@ -355,8 +200,6 @@ def repair(
|
|
| 355 |
Path | None,
|
| 356 |
typer.Option(
|
| 357 |
"--schema",
|
| 358 |
-
exists=True,
|
| 359 |
-
readable=True,
|
| 360 |
help="Path to a YAML schema file with column types and FDs.",
|
| 361 |
),
|
| 362 |
] = None,
|
|
@@ -400,6 +243,10 @@ def repair(
|
|
| 400 |
str,
|
| 401 |
typer.Option("--llm-model", help="Model name for fd_violation LLM fallback."),
|
| 402 |
] = "gemini-2.0-flash",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
) -> None:
|
| 404 |
"""Detect, propose, and optionally apply reversible repairs to a CSV."""
|
| 405 |
if dry_run == apply:
|
|
@@ -410,58 +257,66 @@ def repair(
|
|
| 410 |
raise typer.Exit(code=2)
|
| 411 |
|
| 412 |
try:
|
|
|
|
|
|
|
|
|
|
| 413 |
parsed_schema = _resolve_schema(schema)
|
| 414 |
-
df = read_csv(path)
|
| 415 |
except Exception as exc:
|
| 416 |
_print_error(str(exc))
|
| 417 |
raise typer.Exit(code=2) from exc
|
| 418 |
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
raise typer.Exit(code=1)
|
| 439 |
|
| 440 |
if dry_run:
|
| 441 |
-
raise typer.Exit(code=0 if
|
| 442 |
-
|
| 443 |
-
if not accepted_fixes:
|
| 444 |
-
raise typer.Exit(code=1)
|
| 445 |
|
| 446 |
-
|
| 447 |
-
if batch_safety.verdict != SafetyVerdict.ALLOW:
|
| 448 |
-
_print_error(batch_safety.reason)
|
| 449 |
raise typer.Exit(code=1)
|
| 450 |
|
| 451 |
-
source_bytes = path.read_bytes()
|
| 452 |
-
try:
|
| 453 |
-
txn_id = _apply_transaction(path, accepted_fixes, source_bytes)
|
| 454 |
-
except Exception as exc:
|
| 455 |
-
_print_error(
|
| 456 |
-
f"Failed to apply repairs: {exc}",
|
| 457 |
-
hint="The source file was restored to its pre-apply bytes.",
|
| 458 |
-
)
|
| 459 |
-
raise typer.Exit(code=1) from exc
|
| 460 |
-
|
| 461 |
output_console.print(
|
| 462 |
Panel(
|
| 463 |
-
f"[green]Applied {len(
|
| 464 |
-
f"Transaction ID: [bold]{txn_id}[/bold]",
|
| 465 |
title="Repair Applied",
|
| 466 |
style="green",
|
| 467 |
)
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import json
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING, Annotated
|
| 8 |
|
|
|
|
| 9 |
import typer
|
| 10 |
from rich.console import Console
|
| 11 |
from rich.panel import Panel
|
| 12 |
|
| 13 |
+
from dataforge.cli.common import load_schema, resolve_cli_path
|
|
|
|
| 14 |
from dataforge.detectors.base import Issue, Schema
|
| 15 |
+
from dataforge.repairers.base import ProposedFix, RepairAttempt
|
| 16 |
+
from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult
|
| 17 |
+
from dataforge.transactions.txn import CellFix
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from dataforge.ui.repair_diff import render_repair_diff
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
import pandas as pd
|
| 22 |
+
|
| 23 |
+
from dataforge.engine.repair import RepairPipelineResult
|
| 24 |
|
| 25 |
_console = Console(stderr=True)
|
| 26 |
|
|
|
|
| 38 |
Raises:
|
| 39 |
ValueError: If a fix references a missing row/column or stale old value.
|
| 40 |
"""
|
| 41 |
+
from dataforge.engine.repair import apply_fixes_to_csv as engine_apply_fixes_to_csv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
return engine_apply_fixes_to_csv(path, fixes)
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def _resolve_schema(schema_path: Path | None) -> Schema | None:
|
| 47 |
"""Resolve an optional schema path into a parsed Schema."""
|
| 48 |
if schema_path is None:
|
| 49 |
return None
|
| 50 |
+
resolved_schema = resolve_cli_path(schema_path)
|
| 51 |
+
if not resolved_schema.exists():
|
| 52 |
+
raise typer.BadParameter(f"Schema file '{schema_path}' does not exist.")
|
| 53 |
+
return load_schema(resolved_schema)
|
| 54 |
|
| 55 |
|
| 56 |
def _print_error(message: str, *, hint: str | None = None) -> None:
|
|
|
|
| 74 |
confirm_escalations: bool,
|
| 75 |
interactive: bool,
|
| 76 |
) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]:
|
| 77 |
+
"""Compatibility wrapper around the shared repair engine proposal stage."""
|
| 78 |
+
from dataforge.engine.repair import propose_repairs as engine_propose_repairs
|
| 79 |
+
|
| 80 |
+
return engine_propose_repairs(
|
| 81 |
+
issues,
|
| 82 |
+
path,
|
| 83 |
+
working_df,
|
| 84 |
+
schema,
|
| 85 |
allow_llm=allow_llm,
|
| 86 |
model=model,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
allow_pii=allow_pii,
|
| 88 |
confirm_pii=confirm_pii,
|
| 89 |
confirm_escalations=confirm_escalations,
|
| 90 |
+
interactive=interactive,
|
| 91 |
+
escalation_resolver=_resolve_escalation,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
)
|
| 93 |
|
| 94 |
|
|
|
|
| 153 |
return len(failed_groups)
|
| 154 |
|
| 155 |
|
| 156 |
+
def _render_failure_summary(result: RepairPipelineResult, console: Console) -> int:
|
| 157 |
+
"""Render a summary for issues that the shared engine could not repair."""
|
| 158 |
+
if not result.failures:
|
| 159 |
+
return 0
|
| 160 |
+
|
| 161 |
+
console.print("[bold yellow]Attempted But Not Fixed[/bold yellow]")
|
| 162 |
+
for failure in result.failures:
|
| 163 |
+
prefix = ""
|
| 164 |
+
if any(label.startswith("fd::") for label in failure.unsat_core):
|
| 165 |
+
prefix = "functional dependency rejection - "
|
| 166 |
+
elif any(label.startswith("domain::") for label in failure.unsat_core):
|
| 167 |
+
prefix = "domain bound rejection - "
|
| 168 |
+
console.print(
|
| 169 |
+
f"{failure.issue_type} at {failure.row}:{failure.column} "
|
| 170 |
+
f"after {failure.attempt_count} attempt(s): {prefix}{failure.reason}",
|
| 171 |
+
overflow="fold",
|
| 172 |
+
)
|
| 173 |
+
return len(result.failures)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _json_result(result: RepairPipelineResult) -> str:
|
| 177 |
+
"""Serialize a repair result for CLI/MCP/CI consumers."""
|
| 178 |
+
return json.dumps(result.model_dump(mode="json"), indent=2, sort_keys=True)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
def _apply_transaction(
|
| 182 |
path: Path,
|
| 183 |
fixes: list[ProposedFix],
|
| 184 |
source_bytes: bytes,
|
| 185 |
) -> str:
|
| 186 |
+
"""Compatibility wrapper around the shared repair engine transaction path."""
|
| 187 |
+
from dataforge.engine.repair import apply_transaction as engine_apply_transaction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
+
return engine_apply_transaction(path, fixes, source_bytes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
|
| 192 |
def repair(
|
| 193 |
path: Annotated[
|
| 194 |
Path,
|
| 195 |
typer.Argument(
|
|
|
|
|
|
|
| 196 |
help="Path to the CSV file to repair.",
|
| 197 |
),
|
| 198 |
],
|
|
|
|
| 200 |
Path | None,
|
| 201 |
typer.Option(
|
| 202 |
"--schema",
|
|
|
|
|
|
|
| 203 |
help="Path to a YAML schema file with column types and FDs.",
|
| 204 |
),
|
| 205 |
] = None,
|
|
|
|
| 243 |
str,
|
| 244 |
typer.Option("--llm-model", help="Model name for fd_violation LLM fallback."),
|
| 245 |
] = "gemini-2.0-flash",
|
| 246 |
+
json_output: Annotated[
|
| 247 |
+
bool,
|
| 248 |
+
typer.Option("--json", help="Print repair result as JSON."),
|
| 249 |
+
] = False,
|
| 250 |
) -> None:
|
| 251 |
"""Detect, propose, and optionally apply reversible repairs to a CSV."""
|
| 252 |
if dry_run == apply:
|
|
|
|
| 257 |
raise typer.Exit(code=2)
|
| 258 |
|
| 259 |
try:
|
| 260 |
+
resolved_path = resolve_cli_path(path)
|
| 261 |
+
if not resolved_path.exists():
|
| 262 |
+
raise typer.BadParameter(f"CSV file '{path}' does not exist.")
|
| 263 |
parsed_schema = _resolve_schema(schema)
|
|
|
|
| 264 |
except Exception as exc:
|
| 265 |
_print_error(str(exc))
|
| 266 |
raise typer.Exit(code=2) from exc
|
| 267 |
|
| 268 |
+
try:
|
| 269 |
+
from dataforge.engine.repair import RepairPipelineRequest, run_repair_pipeline
|
| 270 |
+
|
| 271 |
+
result = run_repair_pipeline(
|
| 272 |
+
RepairPipelineRequest(
|
| 273 |
+
source_path=resolved_path,
|
| 274 |
+
mode="apply" if apply else "dry_run",
|
| 275 |
+
schema=parsed_schema,
|
| 276 |
+
allow_llm=allow_llm,
|
| 277 |
+
model=llm_model,
|
| 278 |
+
allow_pii=allow_pii,
|
| 279 |
+
confirm_pii=confirm_pii,
|
| 280 |
+
confirm_escalations=confirm_escalations,
|
| 281 |
+
interactive=apply,
|
| 282 |
+
)
|
| 283 |
+
)
|
| 284 |
+
except Exception as exc:
|
| 285 |
+
_print_error(
|
| 286 |
+
f"Failed to apply repairs: {exc}" if apply else f"Failed to repair: {exc}",
|
| 287 |
+
hint="The source file was restored to its pre-apply bytes." if apply else None,
|
| 288 |
+
)
|
| 289 |
+
raise typer.Exit(code=1 if apply else 2) from exc
|
| 290 |
|
| 291 |
+
if json_output:
|
| 292 |
+
typer.echo(_json_result(result))
|
| 293 |
+
raise typer.Exit(code=0 if result.fixes else 1)
|
| 294 |
|
| 295 |
+
output_console = Console()
|
| 296 |
+
render_repair_diff(result.fixes, output_console, file_path=str(resolved_path))
|
| 297 |
+
failed_issue_count = _render_failure_summary(result, output_console)
|
| 298 |
+
|
| 299 |
+
if not result.fixes and failed_issue_count == 0:
|
| 300 |
+
if result.receipt.reason != "No accepted fixes were produced.":
|
| 301 |
+
output_console.print(
|
| 302 |
+
Panel(
|
| 303 |
+
f"[yellow]{result.receipt.reason}[/yellow]",
|
| 304 |
+
title="Repair Summary",
|
| 305 |
+
style="yellow",
|
| 306 |
+
)
|
| 307 |
+
)
|
| 308 |
raise typer.Exit(code=1)
|
| 309 |
|
| 310 |
if dry_run:
|
| 311 |
+
raise typer.Exit(code=0 if result.fixes else 1)
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
+
if not result.fixes or not result.receipt.applied:
|
|
|
|
|
|
|
| 314 |
raise typer.Exit(code=1)
|
| 315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
output_console.print(
|
| 317 |
Panel(
|
| 318 |
+
f"[green]Applied {len(result.fixes)} fix(es).[/green]\n"
|
| 319 |
+
f"Transaction ID: [bold]{result.receipt.txn_id}[/bold]",
|
| 320 |
title="Repair Applied",
|
| 321 |
style="green",
|
| 322 |
)
|
dataforge/cli/watch.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI subcommand: ``dataforge watch``."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Annotated, Literal
|
| 9 |
+
|
| 10 |
+
import typer
|
| 11 |
+
from rich.console import Console
|
| 12 |
+
from rich.panel import Panel
|
| 13 |
+
|
| 14 |
+
from dataforge.cli.common import load_schema, read_csv, resolve_cli_path
|
| 15 |
+
from dataforge.detectors import run_all_detectors
|
| 16 |
+
from dataforge.detectors.base import Schema
|
| 17 |
+
from dataforge.ui.profile_view import render_profile_table
|
| 18 |
+
from dataforge.ui.repair_diff import render_repair_diff
|
| 19 |
+
|
| 20 |
+
_console = Console(stderr=True)
|
| 21 |
+
|
| 22 |
+
WatchAction = Literal["profile", "repair"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _load_optional_schema(schema_path: Path | None) -> Schema | None:
|
| 26 |
+
if schema_path is None:
|
| 27 |
+
return None
|
| 28 |
+
resolved_schema = resolve_cli_path(schema_path)
|
| 29 |
+
if not resolved_schema.exists():
|
| 30 |
+
raise typer.BadParameter(f"Schema file '{schema_path}' does not exist.")
|
| 31 |
+
return load_schema(resolved_schema)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _profile_once(path: Path, schema: Schema | None, json_output: bool) -> None:
|
| 35 |
+
df = read_csv(path)
|
| 36 |
+
issues = run_all_detectors(df, schema)
|
| 37 |
+
if json_output:
|
| 38 |
+
typer.echo(
|
| 39 |
+
json.dumps(
|
| 40 |
+
{
|
| 41 |
+
"event": "profile",
|
| 42 |
+
"path": str(path),
|
| 43 |
+
"issues_count": len(issues),
|
| 44 |
+
"issues": [issue.model_dump(mode="json") for issue in issues],
|
| 45 |
+
},
|
| 46 |
+
indent=2,
|
| 47 |
+
sort_keys=True,
|
| 48 |
+
)
|
| 49 |
+
)
|
| 50 |
+
return
|
| 51 |
+
render_profile_table(issues, Console(), file_path=str(path))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _repair_once(path: Path, schema: Schema | None, apply: bool, json_output: bool) -> None:
|
| 55 |
+
from dataforge.engine.repair import RepairPipelineRequest, run_repair_pipeline
|
| 56 |
+
|
| 57 |
+
result = run_repair_pipeline(
|
| 58 |
+
RepairPipelineRequest(
|
| 59 |
+
source_path=path,
|
| 60 |
+
mode="apply" if apply else "dry_run",
|
| 61 |
+
schema=schema,
|
| 62 |
+
interactive=False,
|
| 63 |
+
)
|
| 64 |
+
)
|
| 65 |
+
if json_output:
|
| 66 |
+
payload = result.model_dump(mode="json")
|
| 67 |
+
payload["event"] = "repair"
|
| 68 |
+
typer.echo(json.dumps(payload, indent=2, sort_keys=True))
|
| 69 |
+
return
|
| 70 |
+
render_repair_diff(result.fixes, Console(), file_path=str(path))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _run_once(path: Path, schema: Schema | None, action: WatchAction, apply: bool, json: bool) -> None:
|
| 74 |
+
if action == "repair":
|
| 75 |
+
_repair_once(path, schema, apply, json)
|
| 76 |
+
else:
|
| 77 |
+
_profile_once(path, schema, json)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def watch(
|
| 81 |
+
path: Annotated[
|
| 82 |
+
Path,
|
| 83 |
+
typer.Argument(help="CSV or dbt artifact path to watch."),
|
| 84 |
+
],
|
| 85 |
+
schema: Annotated[
|
| 86 |
+
Path | None,
|
| 87 |
+
typer.Option("--schema", help="Path to a YAML schema file with column types and FDs."),
|
| 88 |
+
] = None,
|
| 89 |
+
action: Annotated[
|
| 90 |
+
WatchAction,
|
| 91 |
+
typer.Option("--action", help="Action to run when the file changes: profile or repair."),
|
| 92 |
+
] = "profile",
|
| 93 |
+
apply: Annotated[
|
| 94 |
+
bool,
|
| 95 |
+
typer.Option("--apply", help="Apply repairs on change. Defaults to dry-run repair."),
|
| 96 |
+
] = False,
|
| 97 |
+
interval: Annotated[
|
| 98 |
+
float,
|
| 99 |
+
typer.Option("--interval", min=0.1, help="Polling interval in seconds."),
|
| 100 |
+
] = 2.0,
|
| 101 |
+
once: Annotated[
|
| 102 |
+
bool,
|
| 103 |
+
typer.Option("--once", help="Run once and exit, useful for CI acceptance."),
|
| 104 |
+
] = False,
|
| 105 |
+
json_output: Annotated[
|
| 106 |
+
bool,
|
| 107 |
+
typer.Option("--json", help="Print watch events as JSON."),
|
| 108 |
+
] = False,
|
| 109 |
+
) -> None:
|
| 110 |
+
"""Poll a path and rerun profile or repair when it changes."""
|
| 111 |
+
resolved_path = resolve_cli_path(path)
|
| 112 |
+
if not resolved_path.exists():
|
| 113 |
+
_console.print(f"[bold red]Watch path not found:[/bold red] {path}")
|
| 114 |
+
raise typer.Exit(code=2)
|
| 115 |
+
parsed_schema = _load_optional_schema(schema)
|
| 116 |
+
|
| 117 |
+
if apply and action != "repair":
|
| 118 |
+
_console.print(
|
| 119 |
+
Panel(
|
| 120 |
+
"--apply is only valid with --action repair.",
|
| 121 |
+
title="Watch Error",
|
| 122 |
+
style="red",
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
raise typer.Exit(code=2)
|
| 126 |
+
|
| 127 |
+
_run_once(resolved_path, parsed_schema, action, apply, json_output)
|
| 128 |
+
if once:
|
| 129 |
+
return
|
| 130 |
+
|
| 131 |
+
last_mtime = resolved_path.stat().st_mtime_ns
|
| 132 |
+
while True:
|
| 133 |
+
time.sleep(interval)
|
| 134 |
+
try:
|
| 135 |
+
current_mtime = resolved_path.stat().st_mtime_ns
|
| 136 |
+
except FileNotFoundError:
|
| 137 |
+
_console.print(f"[bold red]Watch path disappeared:[/bold red] {resolved_path}")
|
| 138 |
+
raise typer.Exit(code=2) from None
|
| 139 |
+
if current_mtime == last_mtime:
|
| 140 |
+
continue
|
| 141 |
+
last_mtime = current_mtime
|
| 142 |
+
_run_once(resolved_path, parsed_schema, action, apply, json_output)
|
dataforge/datasets/embedded/hospital/clean.csv
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
id,age,admission_date,name
|
| 2 |
+
1,30,2020-01-01,Alice
|
| 3 |
+
2,45,2020-01-02,Bob
|
| 4 |
+
3,30,2020-01-03,Carol
|
| 5 |
+
4,29,2020-01-04,Dave
|
| 6 |
+
5,35,2020-01-05,Eve
|
| 7 |
+
6,51,2020-01-06,Frank
|
| 8 |
+
7,40,2020-01-07,Grace
|
| 9 |
+
8,35,2020-01-08,Heidi
|
| 10 |
+
9,28,2020-01-09,Ivan
|
| 11 |
+
10,60,2020-01-10,Judy
|
dataforge/datasets/embedded/hospital/dirty.csv
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
id,age,admission_date,name
|
| 2 |
+
1,30,2020-01-01,Alice
|
| 3 |
+
2,45,2020-01-02,Bob
|
| 4 |
+
3,N/A,2020-01-03,Carol
|
| 5 |
+
4,29,2020-01-04,Dave
|
| 6 |
+
5,null,2020-01-05,Eve
|
| 7 |
+
6,51,2020-01-06,Frank
|
| 8 |
+
7,40,2020-01-07,Grace
|
| 9 |
+
8,35,2020-01-08,Heidi
|
| 10 |
+
9,28,2020-01-09,Ivan
|
| 11 |
+
10,60,2020-01-10,Judy
|
dataforge/datasets/real_world.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
|
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from pathlib import Path
|
| 7 |
|
|
@@ -16,6 +18,9 @@ class DatasetDownloadError(RuntimeError):
|
|
| 16 |
"""Raised when a real-world dataset cannot be downloaded or loaded from cache."""
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
| 19 |
class GroundTruthCell(BaseModel):
|
| 20 |
"""Single cell-level dirty-to-clean correction used for benchmark scoring."""
|
| 21 |
|
|
@@ -57,7 +62,11 @@ def _read_cached_csv(path: Path) -> pd.DataFrame:
|
|
| 57 |
|
| 58 |
def _download_bytes(url: str) -> bytes:
|
| 59 |
"""Download raw CSV bytes from an upstream source URL."""
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
response = client.get(url)
|
| 62 |
response.raise_for_status()
|
| 63 |
return response.content
|
|
@@ -67,8 +76,19 @@ def _download_to_cache(metadata: DatasetMetadata, dataset_dir: Path) -> None:
|
|
| 67 |
"""Download dirty/clean CSV files into the dataset cache directory."""
|
| 68 |
dataset_dir.mkdir(parents=True, exist_ok=True)
|
| 69 |
dirty_url, clean_url = metadata.source_urls
|
|
|
|
| 70 |
(dataset_dir / "dirty.csv").write_bytes(_download_bytes(dirty_url))
|
| 71 |
(dataset_dir / "clean.csv").write_bytes(_download_bytes(clean_url))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
def _manual_download_message(metadata: DatasetMetadata, dataset_dir: Path, cause: Exception) -> str:
|
|
@@ -153,16 +173,26 @@ def load_real_world_dataset(
|
|
| 153 |
dirty_path = dataset_dir / "dirty.csv"
|
| 154 |
clean_path = dataset_dir / "clean.csv"
|
| 155 |
|
|
|
|
|
|
|
|
|
|
| 156 |
if not dirty_path.exists() or not clean_path.exists():
|
|
|
|
| 157 |
try:
|
| 158 |
_download_to_cache(metadata, dataset_dir)
|
| 159 |
except Exception as exc: # pragma: no cover - exercised through tests via monkeypatch
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
if len(dirty_df.index) != len(clean_df.index):
|
| 168 |
raise ValueError(f"Dataset '{name}' dirty/clean row counts do not match.")
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
| 18 |
"""Raised when a real-world dataset cannot be downloaded or loaded from cache."""
|
| 19 |
|
| 20 |
|
| 21 |
+
_LOGGER = logging.getLogger("dataforge.datasets.real_world")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
class GroundTruthCell(BaseModel):
|
| 25 |
"""Single cell-level dirty-to-clean correction used for benchmark scoring."""
|
| 26 |
|
|
|
|
| 62 |
|
| 63 |
def _download_bytes(url: str) -> bytes:
    """Download raw CSV bytes from an upstream source URL.

    The request timeout (in seconds) is read from the
    ``DATAFORGE_DOWNLOAD_TIMEOUT_S`` environment variable and defaults to 5.
    Values that are not numeric, not finite, or not strictly positive fall
    back to that default instead of being handed to the HTTP client.

    Args:
        url: Source URL to fetch; redirects are followed.

    Returns:
        The response body as bytes.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    default_timeout = 5.0
    try:
        timeout = float(os.environ.get("DATAFORGE_DOWNLOAD_TIMEOUT_S", "5"))
    except ValueError:
        timeout = default_timeout
    # float() also accepts "nan", "inf", "0" and negative numbers; none of
    # them is a usable timeout, so treat them like any other invalid setting.
    if not 0.0 < timeout < float("inf"):
        timeout = default_timeout
    with httpx.Client(timeout=timeout, follow_redirects=True) as client:
        response = client.get(url)
        response.raise_for_status()
        return response.content
|
|
|
|
| 76 |
"""Download dirty/clean CSV files into the dataset cache directory."""
|
| 77 |
dataset_dir.mkdir(parents=True, exist_ok=True)
|
| 78 |
dirty_url, clean_url = metadata.source_urls
|
| 79 |
+
_LOGGER.info("dataset_download_start name=%s dir=%s", metadata.name, dataset_dir)
|
| 80 |
(dataset_dir / "dirty.csv").write_bytes(_download_bytes(dirty_url))
|
| 81 |
(dataset_dir / "clean.csv").write_bytes(_download_bytes(clean_url))
|
| 82 |
+
_LOGGER.info("dataset_download_complete name=%s dir=%s", metadata.name, dataset_dir)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _load_embedded_dataset(name: str) -> tuple[pd.DataFrame, pd.DataFrame] | None:
|
| 86 |
+
root = Path(__file__).parent / "embedded" / name
|
| 87 |
+
dirty_path = root / "dirty.csv"
|
| 88 |
+
clean_path = root / "clean.csv"
|
| 89 |
+
if not dirty_path.exists() or not clean_path.exists():
|
| 90 |
+
return None
|
| 91 |
+
return _read_cached_csv(dirty_path), _read_cached_csv(clean_path)
|
| 92 |
|
| 93 |
|
| 94 |
def _manual_download_message(metadata: DatasetMetadata, dataset_dir: Path, cause: Exception) -> str:
|
|
|
|
| 173 |
dirty_path = dataset_dir / "dirty.csv"
|
| 174 |
clean_path = dataset_dir / "clean.csv"
|
| 175 |
|
| 176 |
+
dirty_df: pd.DataFrame | None = None
|
| 177 |
+
clean_df: pd.DataFrame | None = None
|
| 178 |
+
|
| 179 |
if not dirty_path.exists() or not clean_path.exists():
|
| 180 |
+
_LOGGER.info("dataset_cache_miss name=%s dir=%s", name, dataset_dir)
|
| 181 |
try:
|
| 182 |
_download_to_cache(metadata, dataset_dir)
|
| 183 |
except Exception as exc: # pragma: no cover - exercised through tests via monkeypatch
|
| 184 |
+
fallback = _load_embedded_dataset(name)
|
| 185 |
+
if fallback is None:
|
| 186 |
+
raise DatasetDownloadError(
|
| 187 |
+
_manual_download_message(metadata, dataset_dir, exc)
|
| 188 |
+
) from exc
|
| 189 |
+
dirty_df, clean_df = fallback
|
| 190 |
+
else:
|
| 191 |
+
_LOGGER.info("dataset_cache_hit name=%s dir=%s", name, dataset_dir)
|
| 192 |
+
|
| 193 |
+
if dirty_df is None or clean_df is None:
|
| 194 |
+
dirty_df = _read_cached_csv(dirty_path)
|
| 195 |
+
clean_df = _read_cached_csv(clean_path)
|
| 196 |
|
| 197 |
if len(dirty_df.index) != len(clean_df.index):
|
| 198 |
raise ValueError(f"Dataset '{name}' dirty/clean row counts do not match.")
|
dataforge/detectors/__init__.py
CHANGED
|
@@ -12,8 +12,6 @@ deduplicated, severity-sorted issue list.
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
-
import pandas as pd
|
| 16 |
-
|
| 17 |
from dataforge.detectors.base import Detector, Issue, Schema, Severity
|
| 18 |
from dataforge.detectors.decimal_shift import DecimalShiftDetector
|
| 19 |
from dataforge.detectors.fd_violation import FDViolationDetector
|
|
@@ -33,14 +31,14 @@ __all__ = [
|
|
| 33 |
_SEVERITY_ORDER = {Severity.UNSAFE: 0, Severity.REVIEW: 1, Severity.SAFE: 2}
|
| 34 |
|
| 35 |
|
| 36 |
-
def run_all_detectors(df:
|
| 37 |
"""Run all registered detectors and return a merged, sorted issue list.
|
| 38 |
|
| 39 |
Issues are deduplicated by (row, column, issue_type) and sorted by
|
| 40 |
severity (UNSAFE first) then confidence (highest first).
|
| 41 |
|
| 42 |
Args:
|
| 43 |
-
df: The input
|
| 44 |
schema: Optional declared schema with column types and constraints.
|
| 45 |
|
| 46 |
Returns:
|
|
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
|
|
|
| 15 |
from dataforge.detectors.base import Detector, Issue, Schema, Severity
|
| 16 |
from dataforge.detectors.decimal_shift import DecimalShiftDetector
|
| 17 |
from dataforge.detectors.fd_violation import FDViolationDetector
|
|
|
|
| 31 |
_SEVERITY_ORDER = {Severity.UNSAFE: 0, Severity.REVIEW: 1, Severity.SAFE: 2}
|
| 32 |
|
| 33 |
|
| 34 |
+
def run_all_detectors(df: object, schema: Schema | None = None) -> list[Issue]:
|
| 35 |
"""Run all registered detectors and return a merged, sorted issue list.
|
| 36 |
|
| 37 |
Issues are deduplicated by (row, column, issue_type) and sorted by
|
| 38 |
severity (UNSAFE first) then confidence (highest first).
|
| 39 |
|
| 40 |
Args:
|
| 41 |
+
df: The input table to analyze.
|
| 42 |
schema: Optional declared schema with column types and constraints.
|
| 43 |
|
| 44 |
Returns:
|
dataforge/detectors/base.py
CHANGED
|
@@ -5,9 +5,9 @@ from __future__ import annotations
|
|
| 5 |
import enum
|
| 6 |
from typing import Literal, Protocol
|
| 7 |
|
| 8 |
-
import pandas as pd
|
| 9 |
from pydantic import BaseModel, Field
|
| 10 |
|
|
|
|
| 11 |
from dataforge.verifier.schema import (
|
| 12 |
AggregateDependency,
|
| 13 |
DomainBound,
|
|
@@ -114,23 +114,23 @@ class Issue(BaseModel):
|
|
| 114 |
class Detector(Protocol):
|
| 115 |
"""Structural protocol that every detector must implement.
|
| 116 |
|
| 117 |
-
A detector is a pure function over tabular data: it receives a
|
| 118 |
and an optional Schema, and returns a list of Issue objects. No LLM calls,
|
| 119 |
no disk I/O, no side effects.
|
| 120 |
|
| 121 |
Example:
|
| 122 |
>>> class MyDetector:
|
| 123 |
... def detect(
|
| 124 |
-
... self, df:
|
| 125 |
... ) -> list[Issue]:
|
| 126 |
... return []
|
| 127 |
"""
|
| 128 |
|
| 129 |
-
def detect(self, df:
|
| 130 |
"""Detect data-quality issues in the given DataFrame.
|
| 131 |
|
| 132 |
Args:
|
| 133 |
-
df: The input
|
| 134 |
schema: Optional declared schema with column types and constraints.
|
| 135 |
|
| 136 |
Returns:
|
|
|
|
| 5 |
import enum
|
| 6 |
from typing import Literal, Protocol
|
| 7 |
|
|
|
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
|
| 10 |
+
from dataforge.table import TableLike
|
| 11 |
from dataforge.verifier.schema import (
|
| 12 |
AggregateDependency,
|
| 13 |
DomainBound,
|
|
|
|
| 114 |
class Detector(Protocol):
|
| 115 |
"""Structural protocol that every detector must implement.
|
| 116 |
|
| 117 |
+
A detector is a pure function over tabular data: it receives a table
|
| 118 |
and an optional Schema, and returns a list of Issue objects. No LLM calls,
|
| 119 |
no disk I/O, no side effects.
|
| 120 |
|
| 121 |
Example:
|
| 122 |
>>> class MyDetector:
|
| 123 |
... def detect(
|
| 124 |
+
... self, df: TableLike, schema: Schema | None = None
|
| 125 |
... ) -> list[Issue]:
|
| 126 |
... return []
|
| 127 |
"""
|
| 128 |
|
| 129 |
+
def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
|
| 130 |
"""Detect data-quality issues in the given DataFrame.
|
| 131 |
|
| 132 |
Args:
|
| 133 |
+
df: The input table to analyze.
|
| 134 |
schema: Optional declared schema with column types and constraints.
|
| 135 |
|
| 136 |
Returns:
|
dataforge/detectors/decimal_shift.py
CHANGED
|
@@ -10,15 +10,10 @@ The detector is **pure**: no LLM calls, no I/O, no side effects.
|
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
import math
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
-
import numpy as np
|
| 16 |
-
import pandas as pd
|
| 17 |
|
| 18 |
from dataforge.detectors.base import Issue, Schema, Severity
|
| 19 |
-
|
| 20 |
-
if TYPE_CHECKING:
|
| 21 |
-
pass
|
| 22 |
|
| 23 |
# Minimum non-null numeric values required for meaningful statistics.
|
| 24 |
_MIN_COLUMN_SIZE = 5
|
|
@@ -70,7 +65,7 @@ class DecimalShiftDetector:
|
|
| 70 |
3
|
| 71 |
"""
|
| 72 |
|
| 73 |
-
def detect(self, df:
|
| 74 |
"""Detect decimal-shift issues in the DataFrame.
|
| 75 |
|
| 76 |
Args:
|
|
@@ -83,13 +78,13 @@ class DecimalShiftDetector:
|
|
| 83 |
"""
|
| 84 |
issues: list[Issue] = []
|
| 85 |
|
| 86 |
-
for col_name in df
|
| 87 |
col_issues = self._check_column(df, str(col_name))
|
| 88 |
issues.extend(col_issues)
|
| 89 |
|
| 90 |
return issues
|
| 91 |
|
| 92 |
-
def _check_column(self, df:
|
| 93 |
"""Check a single column for decimal-shift outliers.
|
| 94 |
|
| 95 |
Args:
|
|
@@ -101,7 +96,7 @@ class DecimalShiftDetector:
|
|
| 101 |
"""
|
| 102 |
# Parse all values to float, keeping track of original indices.
|
| 103 |
parsed: list[tuple[int, float, str]] = []
|
| 104 |
-
for row_idx, val in enumerate(df
|
| 105 |
fval = _try_float(val)
|
| 106 |
if fval is not None:
|
| 107 |
parsed.append((row_idx, fval, str(val)))
|
|
@@ -109,11 +104,10 @@ class DecimalShiftDetector:
|
|
| 109 |
if len(parsed) < _MIN_COLUMN_SIZE:
|
| 110 |
return []
|
| 111 |
|
| 112 |
-
|
| 113 |
-
median = float(np.median(values))
|
| 114 |
|
| 115 |
# If median is zero or very close, we cannot compute meaningful ratios.
|
| 116 |
-
if abs(
|
| 117 |
return []
|
| 118 |
|
| 119 |
issues: list[Issue] = []
|
|
@@ -121,7 +115,7 @@ class DecimalShiftDetector:
|
|
| 121 |
if abs(fval) < 1e-10:
|
| 122 |
continue
|
| 123 |
|
| 124 |
-
ratio = fval /
|
| 125 |
if abs(ratio) < 1e-10:
|
| 126 |
continue
|
| 127 |
|
|
@@ -147,13 +141,13 @@ class DecimalShiftDetector:
|
|
| 147 |
reason = (
|
| 148 |
f"Value {fval:g} in column '{col_name}' appears to be "
|
| 149 |
f"~{int(correction_factor)}x the typical value "
|
| 150 |
-
f"(median ~{
|
| 151 |
)
|
| 152 |
else:
|
| 153 |
reason = (
|
| 154 |
f"Value {fval:g} in column '{col_name}' appears to be "
|
| 155 |
f"~{1.0 / correction_factor:g}x too small compared to "
|
| 156 |
-
f"the typical value (median ~{
|
| 157 |
)
|
| 158 |
|
| 159 |
issues.append(
|
|
|
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
import math
|
| 13 |
+
from statistics import median
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
from dataforge.detectors.base import Issue, Schema, Severity
|
| 16 |
+
from dataforge.table import TableLike, column_names, column_values
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Minimum non-null numeric values required for meaningful statistics.
|
| 19 |
_MIN_COLUMN_SIZE = 5
|
|
|
|
| 65 |
3
|
| 66 |
"""
|
| 67 |
|
| 68 |
+
def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
|
| 69 |
"""Detect decimal-shift issues in the DataFrame.
|
| 70 |
|
| 71 |
Args:
|
|
|
|
| 78 |
"""
|
| 79 |
issues: list[Issue] = []
|
| 80 |
|
| 81 |
+
for col_name in column_names(df):
|
| 82 |
col_issues = self._check_column(df, str(col_name))
|
| 83 |
issues.extend(col_issues)
|
| 84 |
|
| 85 |
return issues
|
| 86 |
|
| 87 |
+
def _check_column(self, df: TableLike, col_name: str) -> list[Issue]:
|
| 88 |
"""Check a single column for decimal-shift outliers.
|
| 89 |
|
| 90 |
Args:
|
|
|
|
| 96 |
"""
|
| 97 |
# Parse all values to float, keeping track of original indices.
|
| 98 |
parsed: list[tuple[int, float, str]] = []
|
| 99 |
+
for row_idx, val in enumerate(column_values(df, col_name)):
|
| 100 |
fval = _try_float(val)
|
| 101 |
if fval is not None:
|
| 102 |
parsed.append((row_idx, fval, str(val)))
|
|
|
|
| 104 |
if len(parsed) < _MIN_COLUMN_SIZE:
|
| 105 |
return []
|
| 106 |
|
| 107 |
+
center = float(median([v for _, v, _ in parsed]))
|
|
|
|
| 108 |
|
| 109 |
# If median is zero or very close, we cannot compute meaningful ratios.
|
| 110 |
+
if abs(center) < 1e-10:
|
| 111 |
return []
|
| 112 |
|
| 113 |
issues: list[Issue] = []
|
|
|
|
| 115 |
if abs(fval) < 1e-10:
|
| 116 |
continue
|
| 117 |
|
| 118 |
+
ratio = fval / center
|
| 119 |
if abs(ratio) < 1e-10:
|
| 120 |
continue
|
| 121 |
|
|
|
|
| 141 |
reason = (
|
| 142 |
f"Value {fval:g} in column '{col_name}' appears to be "
|
| 143 |
f"~{int(correction_factor)}x the typical value "
|
| 144 |
+
f"(median ~{center:g})"
|
| 145 |
)
|
| 146 |
else:
|
| 147 |
reason = (
|
| 148 |
f"Value {fval:g} in column '{col_name}' appears to be "
|
| 149 |
f"~{1.0 / correction_factor:g}x too small compared to "
|
| 150 |
+
f"the typical value (median ~{center:g})"
|
| 151 |
)
|
| 152 |
|
| 153 |
issues.append(
|
dataforge/detectors/fd_violation.py
CHANGED
|
@@ -12,14 +12,8 @@ The detector is **pure**: no LLM calls, no I/O, no side effects.
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
-
from typing import TYPE_CHECKING
|
| 16 |
-
|
| 17 |
-
import pandas as pd
|
| 18 |
-
|
| 19 |
from dataforge.detectors.base import Issue, Schema, Severity
|
| 20 |
-
|
| 21 |
-
if TYPE_CHECKING:
|
| 22 |
-
pass
|
| 23 |
|
| 24 |
|
| 25 |
class FDViolationDetector:
|
|
@@ -49,7 +43,7 @@ class FDViolationDetector:
|
|
| 49 |
2
|
| 50 |
"""
|
| 51 |
|
| 52 |
-
def detect(self, df:
|
| 53 |
"""Detect FD-violation issues in the DataFrame.
|
| 54 |
|
| 55 |
Args:
|
|
@@ -73,7 +67,7 @@ class FDViolationDetector:
|
|
| 73 |
|
| 74 |
def _check_fd(
|
| 75 |
self,
|
| 76 |
-
df:
|
| 77 |
determinant: tuple[str, ...],
|
| 78 |
dependent: str,
|
| 79 |
) -> list[Issue]:
|
|
@@ -91,34 +85,37 @@ class FDViolationDetector:
|
|
| 91 |
|
| 92 |
# Verify all columns exist in the DataFrame.
|
| 93 |
all_cols = [*determinant_columns, dependent]
|
|
|
|
| 94 |
for col in all_cols:
|
| 95 |
-
if col not in
|
| 96 |
return []
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
if
|
| 104 |
return []
|
| 105 |
|
| 106 |
-
# Group by determinant and find groups with multiple distinct
|
| 107 |
-
# dependent values.
|
| 108 |
issues: list[Issue] = []
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
| 113 |
if len(unique_deps) <= 1:
|
| 114 |
continue
|
| 115 |
|
| 116 |
-
# All rows in this group are part of the violation.
|
| 117 |
det_desc = self._format_determinant(determinant, group_key)
|
| 118 |
unique_str = ", ".join(repr(str(v)) for v in unique_deps)
|
| 119 |
|
| 120 |
-
for idx in
|
| 121 |
-
actual_val =
|
| 122 |
reason = (
|
| 123 |
f"Functional dependency {determinant} -> {dependent} "
|
| 124 |
f"violated: {det_desc} maps to multiple values: "
|
|
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from dataforge.detectors.base import Issue, Schema, Severity
|
| 16 |
+
from dataforge.table import TableLike, cell_value, column_names, row_count
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class FDViolationDetector:
|
|
|
|
| 43 |
2
|
| 44 |
"""
|
| 45 |
|
| 46 |
+
def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
|
| 47 |
"""Detect FD-violation issues in the DataFrame.
|
| 48 |
|
| 49 |
Args:
|
|
|
|
| 67 |
|
| 68 |
def _check_fd(
|
| 69 |
self,
|
| 70 |
+
df: TableLike,
|
| 71 |
determinant: tuple[str, ...],
|
| 72 |
dependent: str,
|
| 73 |
) -> list[Issue]:
|
|
|
|
| 85 |
|
| 86 |
# Verify all columns exist in the DataFrame.
|
| 87 |
all_cols = [*determinant_columns, dependent]
|
| 88 |
+
available_columns = set(column_names(df))
|
| 89 |
for col in all_cols:
|
| 90 |
+
if col not in available_columns:
|
| 91 |
return []
|
| 92 |
|
| 93 |
+
groups: dict[tuple[str, ...], list[int]] = {}
|
| 94 |
+
for row in range(row_count(df)):
|
| 95 |
+
group_key = tuple(cell_value(df, row, column) for column in determinant_columns)
|
| 96 |
+
if any(value == "" for value in group_key):
|
| 97 |
+
continue
|
| 98 |
+
groups.setdefault(group_key, []).append(row)
|
| 99 |
|
| 100 |
+
if not groups:
|
| 101 |
return []
|
| 102 |
|
|
|
|
|
|
|
| 103 |
issues: list[Issue] = []
|
| 104 |
+
for group_key, row_indices in groups.items():
|
| 105 |
+
unique_deps: list[str] = []
|
| 106 |
+
for row in row_indices:
|
| 107 |
+
value = cell_value(df, row, dependent)
|
| 108 |
+
if value == "" or value in unique_deps:
|
| 109 |
+
continue
|
| 110 |
+
unique_deps.append(value)
|
| 111 |
if len(unique_deps) <= 1:
|
| 112 |
continue
|
| 113 |
|
|
|
|
| 114 |
det_desc = self._format_determinant(determinant, group_key)
|
| 115 |
unique_str = ", ".join(repr(str(v)) for v in unique_deps)
|
| 116 |
|
| 117 |
+
for idx in row_indices:
|
| 118 |
+
actual_val = cell_value(df, idx, dependent)
|
| 119 |
reason = (
|
| 120 |
f"Functional dependency {determinant} -> {dependent} "
|
| 121 |
f"violated: {det_desc} maps to multiple values: "
|
dataforge/detectors/type_mismatch.py
CHANGED
|
@@ -10,14 +10,9 @@ The detector is **pure**: no LLM calls, no I/O, no side effects.
|
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
import re
|
| 13 |
-
from typing import TYPE_CHECKING
|
| 14 |
-
|
| 15 |
-
import pandas as pd
|
| 16 |
|
| 17 |
from dataforge.detectors.base import Issue, Schema, Severity
|
| 18 |
-
|
| 19 |
-
if TYPE_CHECKING:
|
| 20 |
-
pass
|
| 21 |
|
| 22 |
# Compiled regexes for type inference.
|
| 23 |
_NUMERIC_RE = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$")
|
|
@@ -69,7 +64,7 @@ class TypeMismatchDetector:
|
|
| 69 |
'N/A'
|
| 70 |
"""
|
| 71 |
|
| 72 |
-
def detect(self, df:
|
| 73 |
"""Detect type-mismatch issues in the DataFrame.
|
| 74 |
|
| 75 |
Args:
|
|
@@ -84,13 +79,13 @@ class TypeMismatchDetector:
|
|
| 84 |
"""
|
| 85 |
issues: list[Issue] = []
|
| 86 |
|
| 87 |
-
for col_name in df
|
| 88 |
col_issues = self._check_column(df, str(col_name))
|
| 89 |
issues.extend(col_issues)
|
| 90 |
|
| 91 |
return issues
|
| 92 |
|
| 93 |
-
def _check_column(self, df:
|
| 94 |
"""Check a single column for type mismatches.
|
| 95 |
|
| 96 |
Args:
|
|
@@ -100,12 +95,10 @@ class TypeMismatchDetector:
|
|
| 100 |
Returns:
|
| 101 |
Issues found in this column.
|
| 102 |
"""
|
| 103 |
-
series = df[col_name]
|
| 104 |
-
|
| 105 |
# Collect (index, value, type) for non-null entries.
|
| 106 |
classified: list[tuple[int, str, str]] = []
|
| 107 |
-
for row_idx, val in enumerate(
|
| 108 |
-
if
|
| 109 |
continue
|
| 110 |
str_val = str(val).strip()
|
| 111 |
if not str_val:
|
|
|
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
import re
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
from dataforge.detectors.base import Issue, Schema, Severity
|
| 15 |
+
from dataforge.table import TableLike, column_names, column_values
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Compiled regexes for type inference.
|
| 18 |
_NUMERIC_RE = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$")
|
|
|
|
| 64 |
'N/A'
|
| 65 |
"""
|
| 66 |
|
| 67 |
+
def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
|
| 68 |
"""Detect type-mismatch issues in the DataFrame.
|
| 69 |
|
| 70 |
Args:
|
|
|
|
| 79 |
"""
|
| 80 |
issues: list[Issue] = []
|
| 81 |
|
| 82 |
+
for col_name in column_names(df):
|
| 83 |
col_issues = self._check_column(df, str(col_name))
|
| 84 |
issues.extend(col_issues)
|
| 85 |
|
| 86 |
return issues
|
| 87 |
|
| 88 |
+
def _check_column(self, df: TableLike, col_name: str) -> list[Issue]:
|
| 89 |
"""Check a single column for type mismatches.
|
| 90 |
|
| 91 |
Args:
|
|
|
|
| 95 |
Returns:
|
| 96 |
Issues found in this column.
|
| 97 |
"""
|
|
|
|
|
|
|
| 98 |
# Collect (index, value, type) for non-null entries.
|
| 99 |
classified: list[tuple[int, str, str]] = []
|
| 100 |
+
for row_idx, val in enumerate(column_values(df, col_name)):
|
| 101 |
+
if val is None:
|
| 102 |
continue
|
| 103 |
str_val = str(val).strip()
|
| 104 |
if not str_val:
|
dataforge/engine/__init__.py
CHANGED
|
@@ -1 +1,33 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public backend engine APIs for DataForge."""
|
| 2 |
+
|
| 3 |
+
from dataforge.engine.repair import (
|
| 4 |
+
CandidateFix,
|
| 5 |
+
RepairFailure,
|
| 6 |
+
RepairMode,
|
| 7 |
+
RepairPipelineRequest,
|
| 8 |
+
RepairPipelineResult,
|
| 9 |
+
RepairReceipt,
|
| 10 |
+
VerifiedFix,
|
| 11 |
+
apply_fixes_to_csv,
|
| 12 |
+
apply_transaction,
|
| 13 |
+
create_repair_transaction,
|
| 14 |
+
propose_repairs,
|
| 15 |
+
run_repair_pipeline,
|
| 16 |
+
source_path_lock,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Public surface of ``dataforge.engine``: exactly the names imported from
# ``dataforge.engine.repair`` above, listed in the same order.
__all__ = [
    "CandidateFix",
    "RepairFailure",
    "RepairMode",
    "RepairPipelineRequest",
    "RepairPipelineResult",
    "RepairReceipt",
    "VerifiedFix",
    "apply_fixes_to_csv",
    "apply_transaction",
    "create_repair_transaction",
    "propose_repairs",
    "run_repair_pipeline",
    "source_path_lock",
]
|
dataforge/engine/repair.py
ADDED
|
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public repair engine for DataForge backend surfaces.
|
| 2 |
+
|
| 3 |
+
The engine is the stable boundary shared by CLI, Playground, MCP, and any
|
| 4 |
+
OpenEnv adapter that needs repair semantics. It keeps the core invariant in one
|
| 5 |
+
place: detect -> propose -> safety -> SMT verification -> journal/snapshot ->
|
| 6 |
+
atomic mutation -> byte-identical revert.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import hashlib
|
| 12 |
+
import os
|
| 13 |
+
import secrets
|
| 14 |
+
import time
|
| 15 |
+
from collections.abc import Callable, Iterator
|
| 16 |
+
from contextlib import contextmanager, suppress
|
| 17 |
+
from datetime import UTC, datetime
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Literal
|
| 20 |
+
|
| 21 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 22 |
+
|
| 23 |
+
from dataforge.detectors import run_all_detectors
|
| 24 |
+
from dataforge.detectors.base import Issue, Schema
|
| 25 |
+
from dataforge.observability import repair_stage_span
|
| 26 |
+
from dataforge.repair_contract import CONTRACT_VERSION
|
| 27 |
+
from dataforge.repairers import build_repairers
|
| 28 |
+
from dataforge.repairers.base import ProposedFix, RepairAttempt, RetryContext
|
| 29 |
+
from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict
|
| 30 |
+
from dataforge.table import (
|
| 31 |
+
Table,
|
| 32 |
+
TableLike,
|
| 33 |
+
cell_value,
|
| 34 |
+
column_names,
|
| 35 |
+
copy_table,
|
| 36 |
+
row_count,
|
| 37 |
+
set_cell_value,
|
| 38 |
+
table_to_csv_bytes,
|
| 39 |
+
)
|
| 40 |
+
from dataforge.table import (
|
| 41 |
+
read_csv as read_table_csv,
|
| 42 |
+
)
|
| 43 |
+
from dataforge.transactions.log import (
|
| 44 |
+
append_applied_event,
|
| 45 |
+
append_created_transaction,
|
| 46 |
+
cache_dir_for,
|
| 47 |
+
sha256_bytes,
|
| 48 |
+
sha256_file,
|
| 49 |
+
snapshot_path_for,
|
| 50 |
+
)
|
| 51 |
+
from dataforge.transactions.txn import CellFix, RepairTransaction, generate_txn_id
|
| 52 |
+
from dataforge.verifier import SMTVerifier, VerificationVerdict
|
| 53 |
+
|
| 54 |
+
# Pipeline run mode, restricted to the two literal strings below.
RepairMode = Literal["dry_run", "apply"]
# Callback signature: receives a proposed fix, the optional schema, a safety
# context, the safety filter and a safety result, and returns a
# ``(SafetyContext, SafetyResult)`` pair.
# NOTE(review): presumably used to resolve escalated safety verdicts, as the
# name suggests — confirm against the call site.
EscalationResolver = Callable[
    [ProposedFix, Schema | None, SafetyContext, SafetyFilter, SafetyResult],
    tuple[SafetyContext, SafetyResult],
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class RepairEngineError(RuntimeError):
    """Root of the exception hierarchy raised by the public repair engine."""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class TransactionApplyError(RepairEngineError):
    """Signals that an apply transaction could not be completed safely."""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class CandidateFix(BaseModel):
    """Public, immutable description of a single proposed cell repair.

    The model is strict, rejects unknown fields and is frozen. ``row`` must be
    non-negative, ``confidence`` lies in ``[0.0, 1.0]``, and ``column``,
    ``detector_id``, ``reason`` and ``provenance`` must be non-empty strings.
    ``operation`` is either ``"update"`` (the default) or ``"delete_row"``.
    """

    model_config = ConfigDict(strict=True, extra="forbid", frozen=True)

    row: int = Field(ge=0)
    column: str = Field(min_length=1)
    old_value: str
    new_value: str
    detector_id: str = Field(min_length=1)
    operation: Literal["update", "delete_row"] = "update"
    reason: str = Field(min_length=1)
    confidence: float = Field(ge=0.0, le=1.0)
    provenance: str = Field(min_length=1)

    @classmethod
    def from_proposed(cls, proposed_fix: ProposedFix) -> CandidateFix:
        """Build the public candidate for an internal repair proposal.

        Cell-level data (row, column, old/new value, detector id, operation)
        comes from ``proposed_fix.fix``; reason, confidence and provenance
        come from the proposal itself.
        """
        cell = proposed_fix.fix
        return cls(
            row=cell.row,
            column=cell.column,
            old_value=cell.old_value,
            new_value=cell.new_value,
            detector_id=cell.detector_id,
            operation=cell.operation,
            reason=proposed_fix.reason,
            confidence=proposed_fix.confidence,
            provenance=proposed_fix.provenance,
        )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class VerifiedFix(CandidateFix):
    """Candidate fix that cleared both the safety gate and SMT verification."""

    # Explanation recorded for the accepted verification outcome.
    verifier_reason: str = Field(min_length=1)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class RepairFailure(BaseModel):
    """Machine-readable record of an issue the pipeline could not repair."""

    # Strict validation, unknown keys rejected, instances frozen.
    model_config = ConfigDict(strict=True, extra="forbid", frozen=True)

    row: int = Field(ge=0)
    column: str = Field(min_length=1)
    issue_type: str = Field(min_length=1)
    status: str = Field(min_length=1)
    reason: str = Field(min_length=1)
    attempt_count: int = Field(ge=1)
    unsat_core: tuple[str, ...] = Field(default_factory=tuple)

    @classmethod
    def from_attempts(cls, attempts: list[RepairAttempt]) -> RepairFailure:
        """Summarise one issue's attempt trace as a public failure record.

        The final attempt supplies the issue location, status, reason and
        unsat core; ``attempt_count`` is the length of the whole trace.
        ``attempts`` must be non-empty (the last element is indexed directly).
        """
        last_attempt = attempts[-1]
        source_issue = last_attempt.issue
        return cls(
            row=source_issue.row,
            column=source_issue.column,
            issue_type=source_issue.issue_type,
            status=last_attempt.status,
            reason=last_attempt.reason,
            attempt_count=len(attempts),
            unsat_core=tuple(last_attempt.unsat_core),
        )
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class RepairReceipt(BaseModel):
    """Stable summary of one repair pipeline run, dry-run or applied."""

    # Strict validation, unknown keys rejected, instances frozen.
    model_config = ConfigDict(strict=True, extra="forbid", frozen=True)

    contract_version: str = CONTRACT_VERSION
    mode: RepairMode
    applied: bool
    reversible: bool
    source_path: str
    # Both digests must be 64 lowercase hex characters when present.
    source_sha256: str = Field(pattern=r"^[0-9a-f]{64}$")
    post_sha256: str | None = Field(default=None, pattern=r"^[0-9a-f]{64}$")
    txn_id: str | None = None
    allowed_columns: list[str] = Field(default_factory=list)
    valid_rows: list[int] = Field(default_factory=list)
    issues_count: int = Field(ge=0)
    fixes_count: int = Field(ge=0)
    reason: str = Field(min_length=1)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class RepairPipelineRequest(BaseModel):
    """Input contract accepted by the public repair pipeline."""

    # ``populate_by_name`` lets callers pass either ``schema`` (the alias) or
    # ``repair_schema``; arbitrary types are allowed for the schema object.
    model_config = ConfigDict(
        strict=True,
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        populate_by_name=True,
    )

    source_path: Path
    mode: RepairMode = "dry_run"
    repair_schema: Schema | None = Field(default=None, alias="schema")
    allow_llm: bool = False
    model: str = "gemini-2.0-flash"
    allow_pii: bool = False
    confirm_pii: bool = False
    confirm_escalations: bool = False
    interactive: bool = False
    create_dry_run_transaction: bool = False
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class RepairPipelineResult(BaseModel):
    """Output contract returned by the public repair pipeline."""

    model_config = ConfigDict(
        strict=True,
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
    )

    receipt: RepairReceipt
    issues: list[Issue]
    fixes: list[VerifiedFix]
    failures: list[RepairFailure] = Field(default_factory=list)
    # Only populated when the pipeline has a transaction object to hand back.
    transaction: RepairTransaction | None = None
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _atomic_write_bytes(path: Path, payload: bytes) -> None:
|
| 194 |
+
"""Write bytes to ``path`` through an atomic same-directory replacement."""
|
| 195 |
+
resolved = path.resolve()
|
| 196 |
+
resolved.parent.mkdir(parents=True, exist_ok=True)
|
| 197 |
+
temp_path = resolved.with_name(f".{resolved.name}.{secrets.token_hex(8)}.tmp")
|
| 198 |
+
try:
|
| 199 |
+
with temp_path.open("xb") as handle:
|
| 200 |
+
handle.write(payload)
|
| 201 |
+
handle.flush()
|
| 202 |
+
os.fsync(handle.fileno())
|
| 203 |
+
os.replace(temp_path, resolved)
|
| 204 |
+
finally:
|
| 205 |
+
if temp_path.exists():
|
| 206 |
+
temp_path.unlink()
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def read_csv(path: Path) -> Table:
    """Load the CSV at ``path`` as a table by delegating to ``read_table_csv``."""
    table = read_table_csv(path)
    return table
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _csv_bytes_after_fixes(path: Path, fixes: list[CellFix]) -> bytes:
    """Apply ``fixes`` to an in-memory copy of the CSV and serialise the result.

    Each fix is checked in order before it is applied: only ``"update"``
    operations are supported, the column must exist, the row must be in
    range, and the cell must still hold the fix's ``old_value``.

    Raises:
        ValueError: If any fix fails one of the checks above.
    """
    table = read_csv(path)
    for cell_fix in fixes:
        if cell_fix.operation != "update":
            raise ValueError(
                f"Unsupported repair operation '{cell_fix.operation}' for row {cell_fix.row}."
            )
        if cell_fix.column not in column_names(table):
            raise ValueError(f"Column '{cell_fix.column}' not found in '{path}'.")
        if not 0 <= cell_fix.row < row_count(table):
            raise ValueError(f"Row {cell_fix.row} is out of bounds for '{path}'.")

        # Guard against applying a fix computed from a different file state.
        observed = cell_value(table, cell_fix.row, cell_fix.column)
        if observed != cell_fix.old_value:
            raise ValueError(
                f"Refusing to apply stale fix for row {cell_fix.row}, column '{cell_fix.column}': "
                f"expected '{cell_fix.old_value}', found '{observed}'."
            )
        set_cell_value(table, cell_fix.row, cell_fix.column, cell_fix.new_value)

    return table_to_csv_bytes(table)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def apply_fixes_to_csv(path: Path, fixes: list[CellFix]) -> str:
    """Apply ordered cell fixes to the CSV at ``path`` and return its new SHA-256.

    The mutated bytes are produced in memory, written through the atomic
    writer, and the hex digest of those same bytes is returned.
    """
    new_bytes = _csv_bytes_after_fixes(path, fixes)
    _atomic_write_bytes(path, new_bytes)
    digest = hashlib.sha256(new_bytes)
    return digest.hexdigest()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _lock_path_for(source_path: Path) -> Path:
|
| 244 |
+
"""Return the filesystem lock path for a source file."""
|
| 245 |
+
digest = hashlib.sha256(str(source_path.resolve()).encode("utf-8")).hexdigest()[:24]
|
| 246 |
+
return source_path.resolve().parent / ".dataforge" / "locks" / f"{digest}.lock"
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@contextmanager
def source_path_lock(
    source_path: Path,
    *,
    timeout_seconds: float = 5.0,
    stale_after_seconds: float = 300.0,
) -> Iterator[None]:
    """Hold an exclusive, file-based lock on ``source_path`` for the ``with`` body.

    The lock is a file created with ``O_CREAT | O_EXCL`` that records the
    owning pid and a UTC timestamp. While another holder exists the call
    polls every 50 ms; a lock file whose mtime is older than
    ``stale_after_seconds`` is removed and acquisition is retried.

    Args:
        source_path: File the lock protects.
        timeout_seconds: Maximum time to wait for the lock.
        stale_after_seconds: Age after which an existing lock is treated as
            abandoned.

    Raises:
        TransactionApplyError: If the lock cannot be acquired before the
            timeout elapses.
    """
    lock_file = _lock_path_for(source_path)
    lock_file.parent.mkdir(parents=True, exist_ok=True)
    give_up_at = time.monotonic() + timeout_seconds
    while True:
        try:
            descriptor = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            try:
                stamp = f"{os.getpid()} {datetime.now(UTC).isoformat()}\n"
                os.write(descriptor, stamp.encode())
            finally:
                os.close(descriptor)
            break
        except FileExistsError as exc:
            try:
                lock_age = time.time() - lock_file.stat().st_mtime
            except OSError:
                # The lock could not be inspected; treat it as fresh.
                lock_age = 0.0
            if lock_age > stale_after_seconds:
                try:
                    lock_file.unlink()
                except OSError:
                    # Removal failed; fall through to the timeout check.
                    pass
                else:
                    continue
            if time.monotonic() >= give_up_at:
                raise TransactionApplyError(
                    f"Timed out waiting for DataForge source lock: {source_path.resolve()}"
                ) from exc
            time.sleep(0.05)

    try:
        yield
    finally:
        # Release the lock; a lock file that is already gone is not an error.
        lock_file.unlink(missing_ok=True)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _write_snapshot_once(snapshot_path: Path, source_bytes: bytes) -> None:
    """Persist ``source_bytes`` as a snapshot that must not already exist.

    The file is opened in exclusive-create mode, so an existing snapshot at
    the same path is never overwritten.

    Raises:
        TransactionApplyError: If a snapshot already exists at
            ``snapshot_path``.
    """
    snapshot_dir = snapshot_path.parent
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    try:
        with snapshot_path.open("xb") as out:
            out.write(source_bytes)
            out.flush()
            os.fsync(out.fileno())
    except FileExistsError as exc:
        message = f"Transaction snapshot already exists: {snapshot_path}"
        raise TransactionApplyError(message) from exc
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def create_repair_transaction(
    path: Path,
    fixes: list[ProposedFix],
    source_bytes: bytes,
    *,
    txn_id: str | None = None,
) -> tuple[RepairTransaction, Path]:
    """Record an unapplied transaction plus a snapshot of the source bytes.

    Args:
        path: Source file the transaction refers to.
        fixes: Proposals whose cell fixes are stored in the transaction.
        source_bytes: Bytes written to the snapshot and hashed for the record.
        txn_id: Transaction id to use; a new one is generated when omitted.

    Returns:
        The created transaction and the path returned by the journal append.
    """
    target = path.resolve()
    chosen_id = txn_id or generate_txn_id()
    snapshot = snapshot_path_for(target, chosen_id)
    _write_snapshot_once(snapshot, source_bytes)

    cell_fixes = [proposal.fix for proposal in fixes]
    journal_entry = RepairTransaction(
        txn_id=chosen_id,
        created_at=datetime.now(UTC),
        source_path=str(target),
        source_sha256=sha256_bytes(source_bytes),
        source_snapshot_path=str(snapshot.resolve()),
        fixes=cell_fixes,
        applied=False,
    )
    try:
        journal_path = append_created_transaction(journal_entry)
    except Exception:
        # Do not leave a snapshot behind for a transaction that was never journaled.
        snapshot.unlink(missing_ok=True)
        raise
    return journal_entry, journal_path
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def apply_transaction(
    path: Path,
    fixes: list[ProposedFix],
    source_bytes: bytes,
    *,
    txn_id: str | None = None,
) -> str:
    """Apply ``fixes`` to ``path`` under a lock, journaling before and after.

    While holding the source lock the function confirms the file still equals
    ``source_bytes``, creates the transaction record and snapshot, applies
    the fixes, and appends an "applied" event. If applying or journaling
    fails, the original bytes are written back before the error propagates.

    Returns:
        The id of the applied transaction.

    Raises:
        TransactionApplyError: If the file changed since detection, or if the
            original bytes could not be restored after a failed apply.
    """
    target = path.resolve()
    fix_total = len(fixes)
    with source_path_lock(target):
        on_disk = target.read_bytes()
        if on_disk != source_bytes:
            raise TransactionApplyError(
                "Refusing to apply repairs because the source file changed after detection."
            )

        with repair_stage_span("dataforge.repair.transaction.create", fixes_count=fix_total):
            transaction, journal_path = create_repair_transaction(
                target,
                fixes,
                source_bytes,
                txn_id=txn_id,
            )
        try:
            with repair_stage_span("dataforge.repair.transaction.apply", fixes_count=fix_total):
                cell_fixes = [proposal.fix for proposal in fixes]
                new_digest = apply_fixes_to_csv(target, cell_fixes)
                append_applied_event(journal_path, transaction.txn_id, post_sha256=new_digest)
        except Exception as exc:
            # Roll back to the pre-apply bytes and verify the restore by hash.
            _atomic_write_bytes(target, source_bytes)
            restored_digest = sha256_file(target)
            if restored_digest != transaction.source_sha256:
                raise TransactionApplyError(
                    "Apply failed and the source file could not be restored to original bytes."
                ) from exc
            raise

    return transaction.txn_id
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def _build_retry_context(issue: Issue, attempts: list[RepairAttempt]) -> RetryContext:
    """Derive retry hints for ``issue`` from its earlier attempts.

    Values proposed by attempts that ended as ``denied``, ``rejected`` or
    ``unknown`` are collected as rejected values. Every attempt contributes
    its reason followed by its unsat-core entries to the hints.
    """
    failed_statuses = {"denied", "rejected", "unknown"}
    rejected = frozenset(
        past.fix.fix.new_value
        for past in attempts
        if past.fix is not None and past.status in failed_statuses
    )
    hint_texts = [text for past in attempts for text in (past.reason, *past.unsat_core)]
    return RetryContext(
        issue=issue,
        previous_attempts=tuple(attempts),
        rejected_values=rejected,
        hints=tuple(hint_texts),
    )
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def propose_repairs(
    issues: list[Issue],
    path: Path,
    working_df: TableLike,
    schema: Schema | None,
    *,
    allow_llm: bool,
    model: str,
    allow_pii: bool,
    confirm_pii: bool,
    confirm_escalations: bool,
    interactive: bool,
    escalation_resolver: EscalationResolver | None = None,
    max_attempts: int = 3,
) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]:
    """Run repairers and gates issue-by-issue against a working dataframe.

    For every issue the matching repairer is asked for a proposal, which then
    passes through the safety filter and the SMT verifier. Denied, rejected
    and unknown outcomes feed a retry context for the next attempt; accepted
    fixes are written into ``working_df`` so later issues see them.

    Args:
        issues: Detected issues to attempt, in order.
        path: Source file; used to locate the repairer cache directory.
        working_df: Table mutated in place with each accepted fix.
        schema: Optional schema passed to repairers, safety filter and verifier.
        allow_llm: Forwarded to ``build_repairers``.
        model: Forwarded to ``build_repairers``.
        allow_pii: Stored in the safety context.
        confirm_pii: Stored in the safety context.
        confirm_escalations: Stored in the safety context.
        interactive: When true and a resolver is given, escalations are handed
            to ``escalation_resolver``.
        escalation_resolver: Callback that may replace the safety context and
            safety result for an escalated proposal. The returned context is
            kept for subsequent proposals.
        max_attempts: Maximum proposals tried per issue (default 3).

    Returns:
        The accepted proposals and, per issue, the list of recorded attempts.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1.")

    with repair_stage_span("dataforge.repair.repairers.build", allow_llm=allow_llm):
        repairers = build_repairers(
            cache_dir=cache_dir_for(path),
            allow_llm=allow_llm,
            model=model,
        )
    safety_filter = SafetyFilter()
    verifier = SMTVerifier()
    safety_context = SafetyContext(
        allow_pii=allow_pii,
        confirm_pii=confirm_pii,
        confirm_escalations=confirm_escalations,
    )

    accepted_fixes: list[ProposedFix] = []
    attempt_groups: list[list[RepairAttempt]] = []

    for issue in issues:
        attempts: list[RepairAttempt] = []
        repairer = repairers.get(issue.issue_type)
        if repairer is None:
            attempts.append(
                RepairAttempt(
                    issue=issue,
                    attempt_number=1,
                    status="attempted_not_fixed",
                    reason="No repairer is registered for this issue type.",
                )
            )
            attempt_groups.append(attempts)
            continue

        accepted = False
        retry_context = RetryContext(issue=issue)
        for attempt_number in range(1, max_attempts + 1):
            candidate = repairer.propose(issue, working_df, schema, retry_context=retry_context)
            if candidate is None:
                # Nothing more to try for this issue.
                attempts.append(
                    RepairAttempt(
                        issue=issue,
                        attempt_number=attempt_number,
                        status="attempted_not_fixed",
                        reason="No repair proposal was available for this issue.",
                    )
                )
                break

            preferred = safety_filter.choose_preferred([candidate], schema, safety_context)
            safety_result = safety_filter.evaluate(preferred, schema, safety_context)
            if (
                safety_result.verdict == SafetyVerdict.ESCALATE
                and interactive
                and escalation_resolver is not None
            ):
                # The resolver may return an updated context and verdict.
                safety_context, safety_result = escalation_resolver(
                    preferred,
                    schema,
                    safety_context,
                    safety_filter,
                    safety_result,
                )

            if safety_result.verdict == SafetyVerdict.DENY:
                attempts.append(
                    RepairAttempt(
                        issue=issue,
                        attempt_number=attempt_number,
                        fix=preferred,
                        status="denied",
                        reason=safety_result.reason,
                    )
                )
                retry_context = _build_retry_context(issue, attempts)
                continue

            if safety_result.verdict == SafetyVerdict.ESCALATE:
                # Unresolved escalations end the attempts for this issue.
                attempts.append(
                    RepairAttempt(
                        issue=issue,
                        attempt_number=attempt_number,
                        fix=preferred,
                        status="escalated",
                        reason=safety_result.reason,
                    )
                )
                break

            with repair_stage_span(
                "dataforge.repair.verifier.verify",
                issue_type=issue.issue_type,
                row=issue.row,
            ):
                verifier_result = verifier.verify(working_df, [preferred], schema)
            if verifier_result.verdict == VerificationVerdict.ACCEPT:
                accepted_fixes.append(preferred)
                # Make the accepted value visible to later issues.
                set_cell_value(
                    working_df,
                    preferred.fix.row,
                    preferred.fix.column,
                    preferred.fix.new_value,
                )
                attempts.append(
                    RepairAttempt(
                        issue=issue,
                        attempt_number=attempt_number,
                        fix=preferred,
                        status="accepted",
                        reason=verifier_result.reason,
                    )
                )
                accepted = True
                break

            attempts.append(
                RepairAttempt(
                    issue=issue,
                    attempt_number=attempt_number,
                    fix=preferred,
                    status=(
                        "rejected"
                        if verifier_result.verdict == VerificationVerdict.REJECT
                        else "unknown"
                    ),
                    reason=verifier_result.reason,
                    unsat_core=verifier_result.unsat_core,
                )
            )
            retry_context = _build_retry_context(issue, attempts)

        # When the attempt budget ran out on a denied/rejected/unknown result,
        # relabel the final attempt as "attempted_not_fixed" and keep its reason.
        if (
            not accepted
            and attempts
            and attempts[-1].status not in {"attempted_not_fixed", "escalated"}
        ):
            last_reason = attempts[-1].reason
            attempts[-1] = attempts[-1].model_copy(
                update={
                    "status": "attempted_not_fixed",
                    "reason": (
                        f"Issue was attempted but not fixed after {len(attempts)} attempt(s). "
                        f"Last failure: {last_reason}"
                    ),
                }
            )
        attempt_groups.append(attempts)

    return accepted_fixes, attempt_groups
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def _verified_fixes(
    fixes: list[ProposedFix],
    attempt_groups: list[list[RepairAttempt]],
) -> list[VerifiedFix]:
    """Convert accepted proposals into public ``VerifiedFix`` payloads.

    The verifier reason is looked up from accepted attempts by
    ``(row, column, new_value)``; proposals without a matching attempt get
    the generic reason ``"Accepted by verifier."``.
    """
    reason_by_cell: dict[tuple[int, str, str], str] = {
        (past.fix.fix.row, past.fix.fix.column, past.fix.fix.new_value): past.reason
        for group in attempt_groups
        for past in group
        if past.status == "accepted" and past.fix is not None
    }

    verified: list[VerifiedFix] = []
    for proposal in fixes:
        cell = proposal.fix
        base_fields = CandidateFix.from_proposed(proposal).model_dump()
        verifier_reason = reason_by_cell.get(
            (cell.row, cell.column, cell.new_value),
            "Accepted by verifier.",
        )
        verified.append(VerifiedFix(**base_fields, verifier_reason=verifier_reason))
    return verified
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def _failed_attempts(attempt_groups: list[list[RepairAttempt]]) -> list[RepairFailure]:
    """Collect a failure record for every issue whose last attempt was not accepted.

    Empty attempt groups are skipped.
    """
    failures: list[RepairFailure] = []
    for group in attempt_groups:
        if not group:
            continue
        if group[-1].status == "accepted":
            continue
        failures.append(RepairFailure.from_attempts(group))
    return failures
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def run_repair_pipeline(request: RepairPipelineRequest) -> RepairPipelineResult:
    """Run the public repair pipeline from detection through optional apply.

    Steps: read the source bytes and table, run all detectors, propose and
    gate repairs on a copy of the table, run the batch safety check, and then
    either apply the accepted fixes (``mode == "apply"``), journal a dry-run
    transaction (``create_dry_run_transaction``), or only report.

    Args:
        request: Validated pipeline input.

    Returns:
        The receipt, detected issues, verified fixes, failures, and the
        transaction object when a dry-run transaction was created.
    """
    source_path = request.source_path.resolve()
    source_bytes = source_path.read_bytes()
    df = read_csv(source_path)
    with repair_stage_span("dataforge.repair.detect", row_count=row_count(df)):
        issues = run_all_detectors(df, request.repair_schema)
    with repair_stage_span("dataforge.repair.propose", issue_count=len(issues)):
        # Proposals are evaluated against a copy so ``df`` keeps the original values.
        accepted_fixes, attempt_groups = propose_repairs(
            issues,
            source_path,
            copy_table(df),
            request.repair_schema,
            allow_llm=request.allow_llm,
            model=request.model,
            allow_pii=request.allow_pii,
            confirm_pii=request.confirm_pii,
            confirm_escalations=request.confirm_escalations,
            interactive=request.interactive,
        )

    with repair_stage_span("dataforge.repair.safety.batch", fixes_count=len(accepted_fixes)):
        batch_safety = SafetyFilter().evaluate_batch(accepted_fixes)
    failures = _failed_attempts(attempt_groups)
    transaction: RepairTransaction | None = None
    txn_id: str | None = None
    post_sha256: str | None = None
    applied = False
    reason = "No accepted fixes were produced."

    if batch_safety.verdict != SafetyVerdict.ALLOW:
        # The batch as a whole was not allowed: drop every fix and report why.
        accepted_fixes = []
        reason = batch_safety.reason
    elif request.mode == "apply" and accepted_fixes:
        # Only the transaction id is reported for applied runs; replaying the
        # log is unnecessary for the public contract, so ``transaction`` stays
        # None and the minimal receipt is intentionally enough for API callers.
        txn_id = apply_transaction(source_path, accepted_fixes, source_bytes)
        post_sha256 = sha256_file(source_path)
        applied = True
        reason = f"Applied {len(accepted_fixes)} fix(es)."
    elif request.create_dry_run_transaction:
        transaction, _log_path = create_repair_transaction(
            source_path, accepted_fixes, source_bytes
        )
        txn_id = transaction.txn_id
        reason = (
            "Dry run completed without mutating the source file."
            if accepted_fixes
            else "No accepted fixes were produced."
        )
    elif accepted_fixes:
        reason = "Dry run completed without mutating the source file."

    receipt = RepairReceipt(
        mode=request.mode,
        applied=applied,
        reversible=True,
        source_path=str(source_path),
        source_sha256=sha256_bytes(source_bytes),
        post_sha256=post_sha256,
        txn_id=txn_id,
        allowed_columns=column_names(df),
        valid_rows=list(range(row_count(df))),
        issues_count=len(issues),
        fixes_count=len(accepted_fixes),
        reason=reason,
    )
    return RepairPipelineResult(
        receipt=receipt,
        issues=issues,
        fixes=_verified_fixes(accepted_fixes, attempt_groups),
        failures=failures,
        transaction=transaction,
    )
|
dataforge/env/__init__.py
CHANGED
|
@@ -1 +1,22 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataForge RL environment — OpenEnv-compatible data-quality environment.
|
| 2 |
+
|
| 3 |
+
Public API:
|
| 4 |
+
DataForgeEnv — Core environment with reset/step/state/close.
|
| 5 |
+
ResetResult — Return type of reset().
|
| 6 |
+
StepResult — Return type of step().
|
| 7 |
+
EnvState — State snapshot from state().
|
| 8 |
+
DataForgeObservation — Agent-visible observation.
|
| 9 |
+
ToolResult — Structured result from each action.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from dataforge.env.environment import DataForgeEnv, EnvState, ResetResult, StepResult
|
| 13 |
+
from dataforge.env.observation import DataForgeObservation, ToolResult
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"DataForgeEnv",
|
| 17 |
+
"DataForgeObservation",
|
| 18 |
+
"EnvState",
|
| 19 |
+
"ResetResult",
|
| 20 |
+
"StepResult",
|
| 21 |
+
"ToolResult",
|
| 22 |
+
]
|
dataforge/env/environment.py
ADDED
|
@@ -0,0 +1,884 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv-compatible DataForge RL environment.
|
| 2 |
+
|
| 3 |
+
Core environment implementing reset/step/state/close for data-quality
|
| 4 |
+
detection, diagnosis, and repair with typed tool-use actions.
|
| 5 |
+
|
| 6 |
+
No LLM calls. No disk writes. Dataset state is in-memory per episode.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import random
|
| 13 |
+
import re
|
| 14 |
+
import uuid
|
| 15 |
+
from difflib import SequenceMatcher
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Any, cast
|
| 18 |
+
|
| 19 |
+
import duckdb
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import sqlglot
|
| 22 |
+
import sqlglot.expressions as sqlglot_exp
|
| 23 |
+
from pydantic import BaseModel, Field
|
| 24 |
+
|
| 25 |
+
from dataforge.agent.scratchpad import Scratchpad
|
| 26 |
+
from dataforge.agent.tool_actions import (
|
| 27 |
+
Action,
|
| 28 |
+
Diagnose,
|
| 29 |
+
Fix,
|
| 30 |
+
Hypothesis,
|
| 31 |
+
InspectRows,
|
| 32 |
+
PatternMatch,
|
| 33 |
+
RootCause,
|
| 34 |
+
SqlQuery,
|
| 35 |
+
StatTest,
|
| 36 |
+
parse_action,
|
| 37 |
+
)
|
| 38 |
+
from dataforge.detectors import run_all_detectors
|
| 39 |
+
from dataforge.detectors.base import Issue
|
| 40 |
+
from dataforge.env.observation import DataForgeObservation, ToolResult
|
| 41 |
+
from dataforge.env.reward import (
|
| 42 |
+
P_FALSE_POS,
|
| 43 |
+
P_INVALID,
|
| 44 |
+
P_WRONG_FIX,
|
| 45 |
+
R_EXPLORE,
|
| 46 |
+
R_ROOT_CAUSE,
|
| 47 |
+
EpisodeMetrics,
|
| 48 |
+
RewardEngine,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
logger = logging.getLogger("dataforge.env")
|
| 52 |
+
|
| 53 |
+
__all__ = [
|
| 54 |
+
"DataForgeEnv",
|
| 55 |
+
"EnvState",
|
| 56 |
+
"ResetResult",
|
| 57 |
+
"StepResult",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
_FIXTURES_DIR = Path(__file__).resolve().parents[1].parent / "fixtures"
|
| 61 |
+
_DEFAULT_CSV = _FIXTURES_DIR / "hospital_10rows.csv"
|
| 62 |
+
_DEFAULT_SCHEMA = _FIXTURES_DIR / "hospital_schema.yaml"
|
| 63 |
+
_MAX_STEPS = 30
|
| 64 |
+
_MAX_RESULT_ROWS = 20
|
| 65 |
+
_TOOL_HISTORY_LIMIT = 5
|
| 66 |
+
_NOISE_EPSILON = 0.15
|
| 67 |
+
_BLOCKED_SQL_FRAGMENTS = (
|
| 68 |
+
"attach",
|
| 69 |
+
"call ",
|
| 70 |
+
"copy ",
|
| 71 |
+
"detach",
|
| 72 |
+
"duckdb_extensions",
|
| 73 |
+
"filename",
|
| 74 |
+
"from_csv_auto",
|
| 75 |
+
"glob(",
|
| 76 |
+
"http://",
|
| 77 |
+
"https://",
|
| 78 |
+
"httpfs",
|
| 79 |
+
"install",
|
| 80 |
+
"load ",
|
| 81 |
+
"mysql_scan",
|
| 82 |
+
"parquet_scan",
|
| 83 |
+
"postgres_scan",
|
| 84 |
+
"pragma",
|
| 85 |
+
"read_csv",
|
| 86 |
+
"read_json",
|
| 87 |
+
"read_parquet",
|
| 88 |
+
"s3://",
|
| 89 |
+
"sqlite_scan",
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 94 |
+
# Result models
|
| 95 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ResetResult(BaseModel):
    """Value returned by ``env.reset()``."""

    # Observation handed back to the caller after the reset.
    observation: DataForgeObservation
    # Free-form extra data; empty by default.
    info: dict[str, Any] = Field(default_factory=dict)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class StepResult(BaseModel):
    """Value returned by ``DataForgeEnv.step()``."""

    # Observation produced by the action that was just executed.
    observation: DataForgeObservation
    # Reward earned by this single step.
    reward: float = 0.0
    # True once the episode has ended.
    done: bool = False
    # Extra key/value details about the step; empty unless the caller fills it.
    info: dict[str, Any] = Field(default_factory=dict)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class EnvState(BaseModel):
    """Snapshot of the environment's bookkeeping counters."""

    # Identity and progress of the episode.
    episode_id: str = ""
    step_count: int = 0
    task_id: str = ""

    # Issue accounting.
    issues_detected: int = 0
    issues_fixed: int = 0
    false_positives: int = 0
    total_issues: int = 0

    # Whether the episode has terminated.
    is_done: bool = False
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 128 |
+
# Environment
|
| 129 |
+
# ═══════════════════════════════════════════════════════════════════════════
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class DataForgeEnv:
|
| 133 |
+
"""OpenEnv-compatible RL environment for data quality repair.
|
| 134 |
+
|
| 135 |
+
Core API: ``reset()``, ``step()``, ``state()``, ``close()`` (no-op).
|
| 136 |
+
|
| 137 |
+
Example::
|
| 138 |
+
|
| 139 |
+
>>> env = DataForgeEnv()
|
| 140 |
+
>>> result = env.reset(seed=42)
|
| 141 |
+
>>> result.observation.done
|
| 142 |
+
False
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
def __init__(self, max_steps: int = _MAX_STEPS) -> None:
    """Create an environment with neutral, empty episode state.

    Nothing is loaded here; ``reset()`` populates the dataset, schema and
    ground-truth issues.

    Args:
        max_steps: Step budget applied to each episode.
    """
    # Episode identity and progress.
    self._max_steps = max_steps
    self._episode_id = ""
    self._step_count = 0
    self._is_done = False
    self._cumulative_reward = 0.0

    # Dataset under repair and its hidden ground truth.
    self._df: pd.DataFrame = pd.DataFrame()
    self._schema_info: dict[str, str] = {}
    self._ground_truth: list[Issue] = []

    # Agent progress bookkeeping.
    self._found_issues: list[dict[str, Any]] = []
    self._fixed_issues: list[dict[str, Any]] = []
    self._false_positives = 0
    self._inspected_rows: set[int] = set()

    # Optional observation noise.
    self._noisy = False
    self._noise_rng: random.Random | None = None

    # Collaborators and per-episode caches.
    self._scratchpad = Scratchpad()
    self._tool_history: list[ToolResult] = []
    self._reward_engine = RewardEngine()
    self._causal_dag_cache: Any = None
    self._root_cause_labels: set[int] = set()
|
| 165 |
+
|
| 166 |
+
# ── Core API ──────────────────────────────────────────────────────────
|
| 167 |
+
|
| 168 |
+
def reset(self, seed: int | None = None, *, noisy: bool = False) -> ResetResult:
    """Start a new episode and return its first observation.

    Clears all per-episode bookkeeping, reloads the fixture CSV (every
    column read as ``str``), loads the optional YAML schema, and runs the
    detectors to build the hidden ground truth.

    Args:
        seed: Seed for the observation-noise RNG. Only consulted when
            ``noisy`` is true; ``None`` is treated as ``0``.
        noisy: If True, enable observation noise (epsilon=0.15).

    Returns:
        ResetResult carrying the initial observation (first five rows)
        and the new episode id in ``info``.
    """
    # Fresh identity and zeroed counters.
    self._episode_id = str(uuid.uuid4())
    self._step_count = 0
    self._cumulative_reward = 0.0
    self._false_positives = 0
    self._is_done = False

    # Empty per-episode collections and caches.
    self._found_issues = []
    self._fixed_issues = []
    self._inspected_rows = set()
    self._scratchpad.reset()
    self._tool_history = []
    self._causal_dag_cache = None
    self._root_cause_labels = set()

    # Noise configuration: an RNG exists only for noisy episodes.
    self._noisy = noisy
    if noisy:
        rng_seed = 0 if seed is None else seed
        self._noise_rng = random.Random(rng_seed)
    else:
        self._noise_rng = None

    # Fixture dataset; schema defaults to "str" for every column.
    self._df = pd.read_csv(_DEFAULT_CSV, dtype=str)
    self._schema_info = dict.fromkeys(self._df.columns, "str")
    if _DEFAULT_SCHEMA.exists():
        import yaml

        with open(_DEFAULT_SCHEMA, encoding="utf-8") as schema_file:
            loaded = yaml.safe_load(schema_file)
        if loaded and "columns" in loaded:
            self._schema_info = loaded["columns"]

    # Hidden ground truth comes from running every detector on the data.
    self._ground_truth = run_all_detectors(self._df)
    logger.info(
        "Episode %s: %d rows, %d ground-truth issues",
        self._episode_id[:8],
        len(self._df),
        len(self._ground_truth),
    )

    # The agent initially sees only the first five rows.
    preview = cast(list[dict[str, Any]], self._df.head(5).to_dict(orient="records"))
    metadata = {
        "episode_id": self._episode_id,
        "total_rows": len(self._df),
        "total_columns": len(self._df.columns),
        "schema": self._schema_info,
    }
    first_obs = DataForgeObservation(
        visible_rows=preview,
        step_budget_remaining=self._max_steps,
        scratchpad_summary=self._scratchpad.summary(),
        metadata=metadata,
    )
    return ResetResult(observation=first_obs, info={"episode_id": self._episode_id})
|
| 227 |
+
|
| 228 |
+
def step(self, action: Action | dict[str, Any]) -> StepResult:
    """Run a single agent action and report what happened.

    Args:
        action: A typed Action model, or a raw dict that is parsed first.

    Returns:
        StepResult with the new observation, this step's reward, and the
        done flag. Once the episode is over, further calls return a
        terminal result with zero reward and consume no budget.
    """
    if self._is_done:
        return self._terminal_result(0.0)

    # The step counts against the budget even if parsing/dispatch fails.
    self._step_count += 1

    if isinstance(action, dict):
        try:
            action = parse_action(action)
        except Exception as exc:
            return self._error_step(str(exc))

    try:
        outcome, step_reward = self._dispatch(action)
    except Exception as exc:
        logger.exception("Action dispatch error at step %d", self._step_count)
        return self._error_step(str(exc))

    # Late-step penalty, then fold into the running total.
    step_reward += self._reward_engine.compute_late_penalty(self._step_count, self._max_steps)
    self._cumulative_reward += step_reward

    # Keep only the most recent tool results.
    self._tool_history.append(outcome)
    if len(self._tool_history) > _TOOL_HISTORY_LIMIT:
        self._tool_history = self._tool_history[-_TOOL_HISTORY_LIMIT:]

    # Budget exhaustion ends the episode; the cumulative reward is then
    # raised to the terminal score if that score is higher.
    finished = self._step_count >= self._max_steps
    if finished:
        self._is_done = True
        terminal_score = self._compute_terminal()
        self._cumulative_reward = max(self._cumulative_reward, terminal_score)

    # Rows are surfaced only for a successful INSPECT_ROWS.
    if outcome.action_type == "INSPECT_ROWS" and outcome.success:
        shown_rows = outcome.data
    else:
        shown_rows = None

    observation = DataForgeObservation(
        visible_rows=shown_rows,
        scratchpad_summary=self._scratchpad.summary(),
        step_budget_remaining=max(0, self._max_steps - self._step_count),
        tool_usage_history=list(self._tool_history),
        latest_result=outcome,
        done=finished,
        reward=step_reward,
        cumulative_reward=self._cumulative_reward,
    )
    return StepResult(observation=observation, reward=step_reward, done=finished)
|
| 287 |
+
|
| 288 |
+
def state(self) -> EnvState:
    """Build a snapshot of the current bookkeeping counters.

    ``task_id`` is not supplied here, so it keeps the model's default.
    """
    snapshot = {
        "episode_id": self._episode_id,
        "step_count": self._step_count,
        "issues_detected": len(self._found_issues),
        "issues_fixed": len(self._fixed_issues),
        "false_positives": self._false_positives,
        "total_issues": len(self._ground_truth),
        "is_done": self._is_done,
    }
    return EnvState(**snapshot)
|
| 299 |
+
|
| 300 |
+
def close(self) -> None:
    """Do nothing; kept so OpenEnv containers can call ``close()``."""
    return None
|
| 302 |
+
|
| 303 |
+
# ── Dispatch ─────────────────��────────────────────────────────────────
|
| 304 |
+
|
| 305 |
+
def _dispatch(self, action: Action) -> tuple[ToolResult, float]:
    """Send ``action`` to the handler for its type.

    Returns:
        ``(tool_result, step_reward)`` from the handler, or an "UNKNOWN"
        failure with the invalid-action penalty when no type matches.
    """
    # Checked in order; the first matching type wins.
    routes = (
        (InspectRows, self._handle_inspect),
        (SqlQuery, self._handle_sql),
        (StatTest, self._handle_stat),
        (PatternMatch, self._handle_pattern),
        (Hypothesis, self._handle_hypothesis),
        (RootCause, self._handle_root_cause),
        (Diagnose, self._handle_diagnose),
        (Fix, self._handle_fix),
    )
    for action_cls, handler in routes:
        if isinstance(action, action_cls):
            return handler(action)

    unknown = ToolResult(
        action_type="UNKNOWN",
        success=False,
        error={"verdict": "error", "reason": "Unknown action type"},
    )
    return unknown, P_INVALID
|
| 328 |
+
|
| 329 |
+
# ── Action handlers ───────────────────────────────────────────────────
|
| 330 |
+
|
| 331 |
+
def _handle_inspect(self, action: InspectRows) -> tuple[ToolResult, float]:
    """Handle INSPECT_ROWS: return the requested dataset rows.

    Out-of-range indices are dropped and the remainder is capped at
    ``_MAX_RESULT_ROWS`` (the same cap the other tools use). Unknown
    column names are ignored; if none of the requested columns exist,
    all columns are returned. Each returned row gains a ``_row_index``
    key holding its position in the dataset.

    Returns:
        ``(tool_result, exploration_bonus)`` on success, or a failure
        result with the invalid-action penalty when no index is valid.
    """
    total_rows = len(self._df)
    valid_indices = [i for i in action.row_indices if 0 <= i < total_rows]
    if not valid_indices:
        return ToolResult(
            action_type="INSPECT_ROWS",
            success=False,
            error={"verdict": "error", "reason": "No valid row indices"},
        ), P_INVALID

    # Shared result-size cap instead of a separate magic number.
    valid_indices = valid_indices[:_MAX_RESULT_ROWS]
    rows = self._df.iloc[valid_indices]
    if action.column_names:
        valid_cols = [c for c in action.column_names if c in self._df.columns]
        if valid_cols:
            rows = rows[valid_cols]

    row_dicts = cast(list[dict[str, Any]], rows.to_dict(orient="records"))
    for record, idx in zip(row_dicts, valid_indices):
        record["_row_index"] = idx

    # Noise is applied to what the agent sees, never to the dataset.
    if self._noisy and self._noise_rng:
        row_dicts = self._inject_noise(row_dicts)

    # Exploration bonus: new rows must be computed before the seen-set
    # is updated.
    new_indices = set(valid_indices) - self._inspected_rows
    self._inspected_rows.update(valid_indices)
    gt_rows = {issue.row for issue in self._ground_truth}
    found_rows = {f["row"] for f in self._found_issues}
    bonus = self._reward_engine.compute_exploration_bonus(
        new_indices,
        self._inspected_rows,
        total_rows,
        gt_rows,
        found_rows,
    )
    return ToolResult(action_type="INSPECT_ROWS", success=True, data=row_dicts), bonus
|
| 370 |
+
|
| 371 |
+
def _handle_sql(self, action: SqlQuery) -> tuple[ToolResult, float]:
    """Handle SQL_QUERY: run one read-only SELECT against the dataset.

    The query is accepted only if it parses to exactly one statement,
    contains none of ``_BLOCKED_SQL_FRAGMENTS``, is a SELECT, and every
    table it references is named ``data``. It then runs in a throwaway
    in-memory DuckDB connection with the DataFrame registered as
    ``data``; at most ``_MAX_RESULT_ROWS`` rows are returned.

    Returns:
        ``(tool_result, 0.0)`` on success, otherwise a failure result
        with the invalid-action penalty.
    """

    def refuse(
        verdict: str, reason: str, constraint: str | None = None
    ) -> tuple[ToolResult, float]:
        # Build the structured failure shared by every rejection path.
        error: dict[str, Any] = {"verdict": verdict, "reason": reason}
        if constraint is not None:
            error["suggested_constraint"] = constraint
        return ToolResult(action_type="SQL_QUERY", success=False, error=error), P_INVALID

    # SqlglotError is the base of ParseError and TokenError, so tokenizer
    # failures are reported as SQL errors too.
    try:
        parsed = [stmt for stmt in sqlglot.parse(action.query) if stmt is not None]
    except sqlglot.errors.SqlglotError as exc:
        return refuse("error", str(exc), "Use valid SQL syntax")

    if len(parsed) != 1:
        return refuse(
            "rejected",
            "Exactly one SELECT statement is allowed.",
            "Use a single read-only SELECT statement.",
        )

    # Plain substring blocklist over the lower-cased, space-padded query.
    normalized_query = f" {action.query.lower()} "
    blocked = next(
        (fragment for fragment in _BLOCKED_SQL_FRAGMENTS if fragment in normalized_query),
        None,
    )
    if blocked is not None:
        return refuse(
            "rejected",
            "SQL_QUERY may only read from the registered data relation.",
            "Query the in-memory data table without file, network, extension, or table functions.",
        )

    stmt = parsed[0]
    if stmt.key not in ("select",):
        return refuse(
            "rejected",
            f"Only SELECT queries allowed, got {stmt.key.upper()}",
            "Use SELECT statements only",
        )

    for table in stmt.find_all(sqlglot_exp.Table):
        if table.name.lower() != "data":
            return refuse(
                "rejected",
                (
                    "SQL_QUERY may only reference the registered data relation; "
                    f"got '{table.name}'."
                ),
                "Use FROM data for tabular queries.",
            )

    try:
        conn = duckdb.connect(":memory:")
        try:
            conn.register("data", self._df)
            result_df = conn.execute(action.query).fetchdf()
        finally:
            # Always release the connection, including when the query fails.
            conn.close()
        rows = result_df.head(_MAX_RESULT_ROWS).to_dict(orient="records")
        return ToolResult(action_type="SQL_QUERY", success=True, data=rows), 0.0
    except duckdb.Error as exc:
        return refuse("error", str(exc))
|
| 454 |
+
|
| 455 |
+
def _handle_stat(self, action: StatTest) -> tuple[ToolResult, float]:
    """Handle STAT_TEST: run a zscore, iqr or ks test on one column.

    The column is coerced to numeric; values that cannot be converted
    are dropped before the test runs.

    Returns:
        ``(tool_result, 0.0)`` with test-specific statistics on success,
        otherwise a failure result with the invalid-action penalty
        (unknown column, no numeric values, or unknown test type).
    """

    def fail(reason: str) -> tuple[ToolResult, float]:
        return ToolResult(
            action_type="STAT_TEST",
            success=False,
            error={"verdict": "error", "reason": reason},
        ), P_INVALID

    if action.column not in self._df.columns:
        return fail(f"Column '{action.column}' not found")

    try:
        col = pd.to_numeric(self._df[action.column], errors="coerce").dropna()
        if len(col) == 0:
            return fail(f"No numeric values in column '{action.column}'")
    except Exception as exc:
        return fail(str(exc))

    from scipy import stats as scipy_stats  # type: ignore[import-untyped]

    if action.test_type == "zscore":
        zscores = scipy_stats.zscore(col)
        # NOTE: ``or`` means a threshold of 0 also falls back to the default.
        threshold = action.threshold or 3.0
        outliers = col.index[abs(zscores) > threshold].tolist()
        data = {
            "test": "zscore",
            "threshold": threshold,
            "outlier_indices": outliers,
            "n_outliers": len(outliers),
            "mean": float(col.mean()),
            "std": float(col.std()),
        }
    elif action.test_type == "iqr":
        q1, q3 = float(col.quantile(0.25)), float(col.quantile(0.75))
        iqr_val = q3 - q1
        # NOTE: ``or`` means a factor of 0 also falls back to the default.
        factor = action.threshold or 1.5
        lower, upper = q1 - factor * iqr_val, q3 + factor * iqr_val
        outliers = col.index[(col < lower) | (col > upper)].tolist()
        data = {
            "test": "iqr",
            "q1": q1,
            "q3": q3,
            "iqr": iqr_val,
            "lower": lower,
            "upper": upper,
            "outlier_indices": outliers,
        }
    elif action.test_type == "ks":
        # Compare against a normal distribution fitted to the sample.
        stat_val, p_val = scipy_stats.kstest(
            col, "norm", args=(float(col.mean()), float(col.std()))
        )
        data = {
            "test": "ks",
            "statistic": float(stat_val),
            "p_value": float(p_val),
            # Cast to a builtin bool: the raw comparison yields a NumPy
            # boolean, which JSON serializers do not accept.
            "normal": bool(p_val > 0.05),
        }
    else:
        return fail(f"Unknown test type: {action.test_type}")

    return ToolResult(action_type="STAT_TEST", success=True, data=data), 0.0
|
| 529 |
+
|
| 530 |
+
def _handle_pattern(self, action: PatternMatch) -> tuple[ToolResult, float]:
    """Handle PATTERN_MATCH: test a regex against every value in a column.

    Values are converted to ``str`` and searched (not full-matched). A
    row is reported when its match outcome equals ``action.expect_match``.

    Returns:
        ``(tool_result, 0.0)`` with up to ``_MAX_RESULT_ROWS`` matches
        and the total count, or a failure result with the invalid-action
        penalty for an unknown column or an invalid pattern.
    """

    def fail(reason: str) -> tuple[ToolResult, float]:
        failure = ToolResult(
            action_type="PATTERN_MATCH",
            success=False,
            error={"verdict": "error", "reason": reason},
        )
        return failure, P_INVALID

    if action.column not in self._df.columns:
        return fail(f"Column '{action.column}' not found")

    try:
        compiled = re.compile(action.pattern)
    except re.error as exc:
        return fail(f"Invalid regex: {exc}")

    hits: list[dict[str, Any]] = [
        {"row": position, "column": action.column, "value": text}
        for position, text in enumerate(self._df[action.column].astype(str))
        if bool(compiled.search(text)) == action.expect_match
    ]
    payload = {"matches": hits[:_MAX_RESULT_ROWS], "total_matches": len(hits)}
    return ToolResult(action_type="PATTERN_MATCH", success=True, data=payload), 0.0
|
| 558 |
+
|
| 559 |
+
def _handle_hypothesis(self, action: Hypothesis) -> tuple[ToolResult, float]:
    """Handle HYPOTHESIS: log the claim and score it against ground truth.

    The claim is always written to the scratchpad. ``R_EXPLORE`` is then
    added once for every ground-truth issue whose row, column and issue
    type all appear in the hypothesis.

    Returns:
        ``(tool_result, credit)``; the credit is also echoed in the
        result data under ``root_cause_credit``.
    """
    self._scratchpad.add_hypothesis(
        action.claim,
        action.affected_rows,
        action.affected_columns,
        action.root_cause_type,
    )

    credit = 0.0
    for truth in self._ground_truth:
        if truth.row not in action.affected_rows:
            continue
        if truth.column not in action.affected_columns:
            continue
        if truth.issue_type != action.root_cause_type:
            continue
        credit += R_EXPLORE

    payload = {"recorded": True, "root_cause_credit": credit}
    return ToolResult(action_type="HYPOTHESIS", success=True, data=payload), credit
|
| 578 |
+
|
| 579 |
+
def _handle_root_cause(self, action: RootCause) -> tuple[ToolResult, float]:
    """Handle ROOT_CAUSE: analyze detected issues for minimal roots.

    ``action.error_indices`` are positions in the list of issues found so
    far. Any index outside ``[0, len(found))`` is rejected; negative
    values are refused rather than silently counting from the end of the
    list. The causal DAG is discovered once per episode and cached.

    Returns:
        ``(tool_result, reward)`` with the analyzer output on success,
        or a failure result with the invalid-action penalty when nothing
        has been detected yet or an index is out of range.
    """
    if not self._found_issues:
        return ToolResult(
            action_type="ROOT_CAUSE",
            success=False,
            error={"verdict": "error", "reason": "No detected issues are available"},
        ), P_INVALID

    issue_count = len(self._found_issues)
    invalid = [idx for idx in action.error_indices if not 0 <= idx < issue_count]
    if invalid:
        return ToolResult(
            action_type="ROOT_CAUSE",
            success=False,
            error={
                "verdict": "error",
                "reason": f"Detected issue indices out of range: {invalid}",
            },
        ), P_INVALID

    from dataforge.causal.pc import discover_causal_dag
    from dataforge.causal.root_cause import CausalRootCauseAnalyzer, evidence_from_issue

    # DAG discovery is reused for the rest of the episode.
    if self._causal_dag_cache is None:
        self._causal_dag_cache = discover_causal_dag(self._df).dag

    selected = [
        evidence_from_issue(index, self._found_issues[index]) for index in action.error_indices
    ]
    result = CausalRootCauseAnalyzer(self._causal_dag_cache).analyze(selected)
    data = result.model_dump(mode="json")
    reward = self._root_cause_reward(set(result.root_indices))
    return ToolResult(action_type="ROOT_CAUSE", success=True, data=data), reward
|
| 612 |
+
|
| 613 |
+
def _handle_diagnose(self, action: Diagnose) -> tuple[ToolResult, float]:
    """Handle DIAGNOSE: check a reported cell against the ground truth.

    Outcomes, in the order they are checked:
      * invalid row/column -> failure with the invalid-action penalty;
      * cell already reported -> ``already_found`` with zero reward;
      * cell is a ground-truth issue -> recorded as found and rewarded
        via ``diagnose_reward`` (which is told whether the type matched);
      * otherwise -> counted as a false positive and penalized.
    """
    row, column = action.row, action.column

    def invalid(reason: str) -> tuple[ToolResult, float]:
        failure = ToolResult(
            action_type="DIAGNOSE",
            success=False,
            error={"verdict": "error", "reason": reason},
        )
        return failure, P_INVALID

    if not 0 <= row < len(self._df):
        return invalid(f"Row {row} out of bounds")
    if column not in self._df.columns:
        return invalid(f"Column '{column}' not found")

    # A cell that was already reported earns nothing further.
    if any(seen["row"] == row and seen["column"] == column for seen in self._found_issues):
        repeat = ToolResult(action_type="DIAGNOSE", success=True, data={"result": "already_found"})
        return repeat, 0.0

    truth = next(
        (gt for gt in self._ground_truth if gt.row == row and gt.column == column),
        None,
    )
    if truth is None:
        self._false_positives += 1
        miss = ToolResult(action_type="DIAGNOSE", success=True, data={"result": "false_positive"})
        return miss, P_FALSE_POS

    type_match = action.issue_type == truth.issue_type
    reward = self._reward_engine.diagnose_reward(type_match)
    # The agent's own type label is what gets recorded.
    self._found_issues.append({"row": row, "column": column, "type": action.issue_type})
    self._scratchpad.confirm_issue(row, column, action.issue_type)
    hit = ToolResult(
        action_type="DIAGNOSE",
        success=True,
        data={"result": "correct", "type_match": type_match},
    )
    return hit, reward
|
| 655 |
+
|
| 656 |
+
def _root_cause_reward(self, root_indices: set[int]) -> float:
    """Score a root-cause answer against the task labels.

    Returns ``R_ROOT_CAUSE`` only when labels exist and ``root_indices``
    equals them exactly; in every other case the reward is ``0.0``.
    """
    labels = self._root_cause_labels
    if labels and root_indices == labels:
        return R_ROOT_CAUSE
    return 0.0
|
| 661 |
+
|
| 662 |
+
def _handle_fix(self, action: Fix) -> tuple[ToolResult, float]:
    """Handle FIX: gate the proposal through safety checks, then score it.

    Order of checks: row/column validity, whether the cell was already
    fixed, the safety/SMT pipeline (a crash there counts as a rejection),
    and finally a comparison with the ground-truth expected value:

      * case-insensitive equality -> ``correct``;
      * numeric values within 1% relative error -> ``partial_numeric``;
      * string similarity of at least 0.85 -> ``partial_string``;
      * anything else -> ``wrong_value`` with the wrong-fix penalty.

    A ground-truth issue with no expected value yields ``detection_only``
    and a cell with no issue yields ``no_issue_at_location``.
    """

    def reject(verdict: str, reason: str) -> tuple[ToolResult, float]:
        failure = ToolResult(
            action_type="FIX",
            success=False,
            error={"verdict": verdict, "reason": reason},
        )
        return failure, P_INVALID

    def report(tag: str, score: float) -> tuple[ToolResult, float]:
        return ToolResult(action_type="FIX", success=True, data={"result": tag}), score

    def accept(tag: str, matched: Issue, *, exact: bool) -> tuple[ToolResult, float]:
        # Reward first, then record the fix and credit the diagnosis.
        score = self._reward_engine.fix_reward(
            exact=exact, has_justification=bool(action.justification)
        )
        self._fixed_issues.append(
            {"row": action.row, "column": action.column, "value": action.new_value}
        )
        self._auto_diagnose(action, matched)
        return report(tag, score)

    if not 0 <= action.row < len(self._df):
        return reject("error", f"Row {action.row} out of bounds")
    if action.column not in self._df.columns:
        return reject("error", f"Column '{action.column}' not found")

    # A cell that was already fixed earns nothing further.
    if any(
        done["row"] == action.row and done["column"] == action.column
        for done in self._fixed_issues
    ):
        return report("already_fixed", 0.0)

    # Safety filter + SMT verifier; any crash is treated as a denial.
    try:
        safety_ok, safety_msg = self._check_safety(action)
    except Exception as exc:
        logger.warning("Safety pipeline failed closed: %s", exc)
        safety_ok = False
        safety_msg = f"Safety pipeline failed closed: {exc}"
    if not safety_ok:
        return reject("rejected", safety_msg)

    # Only the first ground-truth issue at this cell is considered.
    target = next(
        (
            gt
            for gt in self._ground_truth
            if gt.row == action.row and gt.column == action.column
        ),
        None,
    )
    if target is None:
        return report("no_issue_at_location", P_WRONG_FIX)
    if target.expected is None:
        return report("detection_only", 0.0)

    proposed = action.new_value
    expected_text = str(target.expected)

    # Exact match (case-insensitive).
    if proposed.strip().lower() == expected_text.lower():
        return accept("correct", target, exact=True)

    # Partial: numeric within 1%; non-numeric values fall through.
    try:
        prov = float(proposed.strip())
        exp = float(expected_text)
        if exp != 0:
            rel_err = abs(prov - exp) / abs(exp)
        else:
            rel_err = abs(prov)
        if rel_err < 0.01:
            return accept("partial_numeric", target, exact=False)
    except (ValueError, TypeError):
        pass

    # Partial: string similarity >= 85%.
    similarity = SequenceMatcher(None, proposed.lower(), expected_text.lower()).ratio()
    if similarity >= 0.85:
        return accept("partial_string", target, exact=False)

    return report("wrong_value", P_WRONG_FIX)
|
| 761 |
+
|
| 762 |
+
# ── Helpers ────────────────────────────────────────────────────────────
|
| 763 |
+
|
| 764 |
+
def _check_safety(self, action: Fix) -> tuple[bool, str]:
    """Run the proposed fix through SafetyFilter and then SMTVerifier.

    Returns:
        ``(True, message)`` when neither stage objects. ``(False, reason)``
        when the filter denies the fix, the verifier rejects it or
        returns unknown, or a required module cannot be imported.
    """
    try:
        from dataforge.repairers.base import ProposedFix
        from dataforge.safety.filter import SafetyContext, SafetyFilter, SafetyVerdict
        from dataforge.transactions.txn import CellFix
        from dataforge.verifier.smt import SMTVerifier, VerificationVerdict

        # Wrap the agent's action in the types the pipeline expects.
        current_value = str(self._df.at[action.row, action.column])
        candidate = ProposedFix(
            fix=CellFix(
                row=action.row,
                column=action.column,
                old_value=current_value,
                new_value=action.new_value,
                detector_id="agent",
            ),
            reason=action.justification,
            confidence=0.8,
            provenance="deterministic",
        )

        # Stage 1: policy filter.
        filter_result = SafetyFilter().evaluate(candidate, None, SafetyContext())
        if filter_result.verdict == SafetyVerdict.DENY:
            return False, f"Safety filter denied: {filter_result.reason}"

        # Stage 2: SMT verification; only a non-reject, non-unknown
        # verdict lets the fix through.
        verdict = SMTVerifier().verify(self._df, [candidate])
        if verdict.verdict == VerificationVerdict.REJECT:
            return False, f"SMT verifier rejected: {verdict.reason}"
        if verdict.verdict == VerificationVerdict.UNKNOWN:
            return False, f"SMT verifier returned unknown: {verdict.reason}"

        return True, "Passed safety and verification"
    except ImportError as exc:
        return False, f"Safety/verifier dependency unavailable: {exc}"
|
| 803 |
+
|
| 804 |
+
def _auto_diagnose(self, action: Fix, issue: Issue) -> None:
    """Record the fixed cell as a found issue if it was not diagnosed yet.

    The entry uses the ground-truth ``issue.issue_type`` as its type.
    """
    for known in self._found_issues:
        if known["row"] == action.row and known["column"] == action.column:
            return  # already credited
    entry = {"row": action.row, "column": action.column, "type": issue.issue_type}
    self._found_issues.append(entry)
|
| 813 |
+
|
| 814 |
+
def _inject_noise(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return copies of ``rows`` with seeded random corruption applied.

    Without a noise RNG the input list is returned unchanged. Otherwise
    each row is shallow-copied and, with probability ``_NOISE_EPSILON``,
    one column other than ``_row_index`` is picked; if its value is a
    string longer than three characters it is either truncated by 1-3
    trailing characters or case-swapped (equal odds).
    """
    rng = self._noise_rng
    if not rng:
        return rows

    corrupted = []
    for original in rows:
        record = dict(original)
        if rng.random() < _NOISE_EPSILON:
            candidates = [key for key in record if key != "_row_index"]
            if candidates:
                target = rng.choice(candidates)
                value = record[target]
                if isinstance(value, str) and len(value) > 3:
                    # The coin flip is drawn before the truncation length
                    # so the RNG sequence stays reproducible.
                    if rng.random() < 0.5:
                        record[target] = value[: -(rng.randint(1, 3))]
                    else:
                        record[target] = value.swapcase()
        corrupted.append(record)
    return corrupted
|
| 834 |
+
|
| 835 |
+
def _compute_terminal(self) -> float:
    """Summarize the episode and return the reward engine's terminal score.

    Fixable issues are the ground-truth issues that carry an expected
    value.
    """
    fixable_count = sum(1 for issue in self._ground_truth if issue.expected is not None)
    summary = EpisodeMetrics(
        found_issues=len(self._found_issues),
        total_issues=len(self._ground_truth),
        fixed_issues=len(self._fixed_issues),
        fixable_issues=fixable_count,
        false_positives=self._false_positives,
    )
    return self._reward_engine.compute_terminal_score(summary)
|
| 846 |
+
|
| 847 |
+
def _error_step(self, message: str) -> StepResult:
    """Build the StepResult for an action that failed to parse or dispatch.

    The failure is logged to the tool history as an ``ERROR`` result and
    the invalid-action penalty is charged. The history is trimmed to
    ``_TOOL_HISTORY_LIMIT`` here as well, so repeated errors cannot grow
    it without bound. If this step uses up the budget, the episode ends
    the same way it does in ``step()``: the cumulative reward is raised
    to the terminal score when that score is higher.

    Args:
        message: Human-readable failure reason placed in the result.

    Returns:
        StepResult whose reward is the invalid-action penalty.
    """
    tr = ToolResult(
        action_type="ERROR", success=False, error={"verdict": "error", "reason": message}
    )
    self._tool_history.append(tr)
    if len(self._tool_history) > _TOOL_HISTORY_LIMIT:
        self._tool_history = self._tool_history[-_TOOL_HISTORY_LIMIT:]

    self._cumulative_reward += P_INVALID

    done = self._step_count >= self._max_steps
    if done:
        self._is_done = True
        # Mirror step(): an episode that ends on a bad action still gets
        # its terminal score considered.
        self._cumulative_reward = max(self._cumulative_reward, self._compute_terminal())

    return StepResult(
        observation=DataForgeObservation(
            step_budget_remaining=max(0, self._max_steps - self._step_count),
            tool_usage_history=list(self._tool_history[-_TOOL_HISTORY_LIMIT:]),
            latest_result=tr,
            done=done,
            reward=P_INVALID,
            cumulative_reward=self._cumulative_reward,
            scratchpad_summary=self._scratchpad.summary(),
        ),
        reward=P_INVALID,
        done=done,
    )
|
| 870 |
+
|
| 871 |
+
def _terminal_result(self, reward: float) -> StepResult:
    """Return a finished-episode step carrying ``reward``.

    The observation reports no remaining budget, ``done=True``, the current
    cumulative reward, the scratchpad summary and the tail of the tool history.
    """
    final_observation = DataForgeObservation(
        step_budget_remaining=0,
        done=True,
        reward=reward,
        cumulative_reward=self._cumulative_reward,
        scratchpad_summary=self._scratchpad.summary(),
        tool_usage_history=list(self._tool_history[-_TOOL_HISTORY_LIMIT:]),
    )
    return StepResult(observation=final_observation, reward=reward, done=True)
|
dataforge/env/observation.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Observation builder for the DataForge RL environment.
|
| 2 |
+
|
| 3 |
+
Constructs agent-visible observations containing partial data views,
|
| 4 |
+
scratchpad summaries, tool results, and step budget information.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
__all__ = ["DataForgeObservation", "ToolResult"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ToolResult(BaseModel):
    """Result of a single tool-use action.

    The model is frozen, so a result cannot be modified after it has been
    created and placed into an observation.

    Args:
        action_type: The action type that produced this result.
        success: Whether the action succeeded.
        data: Action-specific result data (rows, stats, matches, etc.).
        error: Structured error information if the action failed.
    """

    action_type: str
    success: bool = True
    data: Any = None
    error: dict[str, Any] | None = None

    # Frozen: field assignment after construction is rejected by pydantic.
    model_config = {"frozen": True}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DataForgeObservation(BaseModel):
    """Agent-visible observation returned after each environment step.

    Every field has a default, so callers only pass what a given step
    produced. The model is frozen.

    Args:
        visible_rows: Dataset rows returned by INSPECT_ROWS or reset.
        detector_hints: Optional hints from detectors (partial ground truth).
        scratchpad_summary: Compact summary of the agent's scratchpad.
        step_budget_remaining: Steps left before auto-finalize.
        tool_usage_history: Most recent tool results for context (the
            environment truncates this list to its own history limit).
        latest_result: Result of the most recent action.
        done: Whether the episode has ended.
        reward: Step reward.
        cumulative_reward: Running total reward for the episode.
        metadata: Additional key-value metadata.
    """

    visible_rows: list[dict[str, Any]] | None = None
    detector_hints: list[str] | None = None
    scratchpad_summary: str = ""
    step_budget_remaining: int = 0
    # default_factory gives every observation its own list/dict instance.
    tool_usage_history: list[ToolResult] = Field(default_factory=list)
    latest_result: ToolResult | None = None
    done: bool = False
    reward: float = 0.0
    cumulative_reward: float = 0.0
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Frozen: field assignment after construction is rejected by pydantic.
    model_config = {"frozen": True}
|
dataforge/env/openenv_core.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv-core adapter for the DataForge RL environment."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from pydantic import Field
|
| 8 |
+
|
| 9 |
+
from dataforge.env.environment import DataForgeEnv
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from openenv.core.env_server import (
|
| 13 |
+
Action as OpenEnvAction,
|
| 14 |
+
)
|
| 15 |
+
from openenv.core.env_server import (
|
| 16 |
+
Environment as OpenEnvEnvironment,
|
| 17 |
+
)
|
| 18 |
+
from openenv.core.env_server import (
|
| 19 |
+
Observation as OpenEnvObservation,
|
| 20 |
+
)
|
| 21 |
+
from openenv.core.env_server import (
|
| 22 |
+
create_app,
|
| 23 |
+
)
|
| 24 |
+
except ImportError as exc: # pragma: no cover - exercised only without openenv extra
|
| 25 |
+
raise RuntimeError(
|
| 26 |
+
"The OpenEnv adapter requires the openenv extra: "
|
| 27 |
+
"pip install 'dataforge15[openenv]'."
|
| 28 |
+
) from exc
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DataForgeOpenEnvAction(OpenEnvAction):
    """Flat OpenEnv action carrying every optional DataForge action field.

    Only ``action_type`` is required; all other fields default to ``None``
    and are dropped from the payload handed to ``DataForgeEnv.step``.
    """

    action_type: str = Field(min_length=1)
    row_indices: list[int] | None = None
    column_names: list[str] | None = None
    query: str | None = None
    sql: str | None = None
    test_type: str | None = None
    test: str | None = None
    column: str | None = None
    threshold: float | None = None
    pattern: str | None = None
    regex: str | None = None
    expect_match: bool | None = None
    claim: str | None = None
    affected_rows: list[int] | None = None
    affected_columns: list[str] | None = None
    root_cause_type: str | None = None
    error_indices: list[int] | None = None
    row: int | None = None
    issue_type: str | None = None
    new_value: str | None = None
    proposed_value: str | None = None
    justification: str | None = None
    fix_type: str | None = None

    def as_dataforge_payload(self) -> dict[str, Any]:
        """Build the plain dict that ``DataForgeEnv.step`` receives.

        Unset (``None``) fields are omitted and any ``metadata`` entry is
        removed before the payload is returned.
        """
        dumped = self.model_dump(exclude_none=True)
        return {name: value for name, value in dumped.items() if name != "metadata"}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class DataForgeOpenEnvObservation(OpenEnvObservation):
    """OpenEnv observation model mirroring DataForge's native observation.

    Tool results are carried as plain dictionaries here because the native
    observation is converted through ``model_dump(mode="json")`` before this
    model is built.
    """

    visible_rows: list[dict[str, Any]] | None = None
    detector_hints: list[str] | None = None
    scratchpad_summary: str = ""
    step_budget_remaining: int = 0
    tool_usage_history: list[dict[str, Any]] = Field(default_factory=list)
    latest_result: dict[str, Any] | None = None
    cumulative_reward: float = 0.0
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _to_openenv_observation(payload: dict[str, Any]) -> DataForgeOpenEnvObservation:
    """Build the OpenEnv observation model from a dumped native observation.

    Missing keys fall back to the same neutral defaults the model declares;
    scalar fields are coerced to their target types.
    """
    history = list(payload.get("tool_usage_history") or [])
    extra = dict(payload.get("metadata") or {})
    summary = str(payload.get("scratchpad_summary", ""))
    budget = int(payload.get("step_budget_remaining", 0))
    finished = bool(payload.get("done", False))
    running_total = float(payload.get("cumulative_reward", 0.0))

    return DataForgeOpenEnvObservation(
        visible_rows=payload.get("visible_rows"),
        detector_hints=payload.get("detector_hints"),
        scratchpad_summary=summary,
        step_budget_remaining=budget,
        tool_usage_history=history,
        latest_result=payload.get("latest_result"),
        done=finished,
        reward=payload.get("reward"),
        cumulative_reward=running_total,
        metadata=extra,
    )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class DataForgeOpenEnv(OpenEnvEnvironment):
    """Adapter exposing a private ``DataForgeEnv`` through the OpenEnv interface."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        super().__init__()
        self._env = DataForgeEnv()
        self._last_observation: DataForgeOpenEnvObservation | None = None

    def _remember(self, native_result: Any) -> DataForgeOpenEnvObservation:
        """Convert a native result's observation, cache it, and return it."""
        converted = _to_openenv_observation(
            native_result.observation.model_dump(mode="json")
        )
        self._last_observation = converted
        return converted

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> DataForgeOpenEnvObservation:
        """Start a new episode; ``episode_id`` and extra kwargs are ignored."""
        del episode_id, kwargs
        return self._remember(self._env.reset(seed=seed))

    def step(
        self,
        action: DataForgeOpenEnvAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> DataForgeOpenEnvObservation:
        """Apply one action; ``timeout_s`` and extra kwargs are ignored."""
        del timeout_s, kwargs
        return self._remember(self._env.step(action.as_dataforge_payload()))

    def state(self) -> DataForgeOpenEnvObservation:
        """Return the cached observation, resetting first if none exists yet."""
        cached = self._last_observation
        return self.reset() if cached is None else cached

    def close(self) -> None:
        """Forward the close call to the wrapped environment."""
        self._env.close()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ASGI application built by OpenEnv from the environment class and its
# action/observation models; created at import time.
app = create_app(
    DataForgeOpenEnv,
    DataForgeOpenEnvAction,
    DataForgeOpenEnvObservation,
    env_name="dataforge-env",
    max_concurrent_envs=64,
)
|
dataforge/env/reward.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward engine for the DataForge RL environment.
|
| 2 |
+
|
| 3 |
+
All constants and formulas are derived bit-for-bit from REWARD_DESIGN.md.
|
| 4 |
+
|
| 5 |
+
Terminal score: detection_rate * 0.40 + fix_rate * 0.60 - false_positives * fp_rate
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"DETECTION_WEIGHT",
|
| 14 |
+
"FALSE_POS_PENALTY_RATE",
|
| 15 |
+
"FIX_WEIGHT",
|
| 16 |
+
"LATE_STEP_THRESHOLD",
|
| 17 |
+
"P_FALSE_POS",
|
| 18 |
+
"P_INVALID",
|
| 19 |
+
"P_LATE_STEP",
|
| 20 |
+
"P_REINSPECT",
|
| 21 |
+
"P_WRONG_FIX",
|
| 22 |
+
"R_DIAGNOSE",
|
| 23 |
+
"R_EXPLORE",
|
| 24 |
+
"R_FIX",
|
| 25 |
+
"R_FIX_PARTIAL",
|
| 26 |
+
"R_JUSTIFY_BONUS",
|
| 27 |
+
"R_ROOT_CAUSE",
|
| 28 |
+
"R_TYPE_BONUS",
|
| 29 |
+
"SPAM_THRESHOLD",
|
| 30 |
+
"EpisodeMetrics",
|
| 31 |
+
"RewardEngine",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# Positive rewards
R_DIAGNOSE: float = 0.10  # base value returned by RewardEngine.diagnose_reward
R_TYPE_BONUS: float = 0.05  # added by diagnose_reward when type_match is true
R_FIX: float = 0.15  # fix_reward base for an exact fix
R_FIX_PARTIAL: float = 0.075  # fix_reward base when the fix is not exact
R_JUSTIFY_BONUS: float = 0.05  # added by fix_reward when a justification is given
R_EXPLORE: float = 0.01  # per-row unit used by compute_exploration_bonus
R_ROOT_CAUSE: float = 0.10  # NOTE(review): not used in this module; presumably applied by the environment -- confirm

# Negative penalties
P_FALSE_POS: float = -0.05  # NOTE(review): not used in this module; presumably a per-step penalty -- confirm
P_WRONG_FIX: float = -0.08  # NOTE(review): not used in this module; presumably a per-step penalty -- confirm
P_LATE_STEP: float = -0.02  # returned by compute_late_penalty past the late-step threshold
P_INVALID: float = -0.01  # exported for callers; not used by RewardEngine itself
P_REINSPECT: float = -0.01  # returned by compute_exploration_bonus when no new rows were inspected

# Thresholds
LATE_STEP_THRESHOLD: float = 0.80  # fraction of max_steps after which P_LATE_STEP applies
DETECTION_WEIGHT: float = 0.40  # weight of detection_rate in the terminal score
FIX_WEIGHT: float = 0.60  # weight of fix_rate in the terminal score
FALSE_POS_PENALTY_RATE: float = 0.05  # terminal-score deduction per false positive
SPAM_THRESHOLD: float = 2.0  # diagnoses-per-issue ratio above which the false-positive rate doubles
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
class EpisodeMetrics:
    """Counters collected over one episode and consumed by the terminal score."""

    found_issues: int = 0
    total_issues: int = 0
    fixed_issues: int = 0
    fixable_issues: int = 0
    false_positives: int = 0

    @property
    def total_diagnoses(self) -> int:
        """Number of diagnoses made: correct ones plus false positives."""
        correct, incorrect = self.found_issues, self.false_positives
        return correct + incorrect
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class RewardEngine:
    """Stateless calculator for per-step rewards and the terminal score."""

    def compute_terminal_score(self, metrics: EpisodeMetrics) -> float:
        """Return the end-of-episode score, clamped to [0, 1] and rounded to 4 places.

        The score is ``detection_rate * DETECTION_WEIGHT + fix_rate * FIX_WEIGHT``
        minus a per-false-positive penalty. An episode with no issues scores 0.0.
        """
        total = metrics.total_issues
        if total == 0:
            return 0.0

        detection_rate = metrics.found_issues / total
        if metrics.fixable_issues > 0:
            fix_rate = metrics.fixed_issues / metrics.fixable_issues
        else:
            fix_rate = 0.0

        # Diagnosing far more cells than there are issues doubles the penalty rate.
        spamming = total > 0 and metrics.total_diagnoses > SPAM_THRESHOLD * total
        fp_rate = FALSE_POS_PENALTY_RATE
        if spamming:
            fp_rate = fp_rate * 2.0

        penalty = metrics.false_positives * fp_rate
        raw = detection_rate * DETECTION_WEIGHT + fix_rate * FIX_WEIGHT - penalty
        return round(max(0.0, min(1.0, raw)), 4)

    def compute_late_penalty(self, step: int, max_steps: int) -> float:
        """Return ``P_LATE_STEP`` once ``step`` exceeds the late threshold, else 0.0."""
        cutoff = int(max_steps * LATE_STEP_THRESHOLD)
        if step > cutoff:
            return P_LATE_STEP
        return 0.0

    def compute_exploration_bonus(
        self,
        new_row_indices: set[int],
        inspected_rows: set[int],
        total_rows: int,
        ground_truth_rows: set[int],
        found_issue_rows: set[int],
    ) -> float:
        """Return the bonus for rows inspected for the first time.

        Returns ``P_REINSPECT`` when no new rows were inspected. Otherwise each
        new row that is a ground-truth row not yet found earns ``R_EXPLORE``,
        plus a coverage term that shrinks as more of the table has been seen.
        """
        if not new_row_indices:
            return P_REINSPECT

        fresh_hits = len(
            [
                row
                for row in new_row_indices
                if row in ground_truth_rows and row not in found_issue_rows
            ]
        )
        bonus = fresh_hits * R_EXPLORE
        if total_rows > 0:
            covered = len(inspected_rows | new_row_indices)
            coverage_ratio = covered / total_rows
            bonus += len(new_row_indices) * R_EXPLORE * 0.5 * (1.0 - coverage_ratio)
        return bonus

    def diagnose_reward(self, type_match: bool) -> float:
        """Return ``R_DIAGNOSE``, plus ``R_TYPE_BONUS`` when the issue type matched."""
        type_bonus = R_TYPE_BONUS if type_match else 0.0
        return R_DIAGNOSE + type_bonus

    def fix_reward(self, exact: bool, has_justification: bool) -> float:
        """Return the fix reward: full or partial base plus an optional justification bonus."""
        if exact:
            base = R_FIX
        else:
            base = R_FIX_PARTIAL
        justification_bonus = R_JUSTIFY_BONUS if has_justification else 0.0
        return base + justification_bonus
|
dataforge/env/server.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server for the DataForge RL environment.
|
| 2 |
+
|
| 3 |
+
Provides OpenEnv-compatible HTTP endpoints:
|
| 4 |
+
POST /reset — Start a new episode
|
| 5 |
+
POST /step — Execute an action
|
| 6 |
+
GET /state — Return current state snapshot
|
| 7 |
+
POST /close — Close an episode's environment and drop its session
|
| 8 |
+
GET /health — Liveness check
|
| 9 |
+
GET /metadata — Environment metadata
|
| 10 |
+
GET /schema — Action/observation JSON schemas
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
from threading import RLock
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 21 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 22 |
+
from pydantic import TypeAdapter
|
| 23 |
+
|
| 24 |
+
from dataforge.agent.tool_actions import Action
|
| 25 |
+
from dataforge.env.environment import DataForgeEnv, EnvState
|
| 26 |
+
from dataforge.env.observation import DataForgeObservation
|
| 27 |
+
from dataforge.http.problem import problem_exception_handler
|
| 28 |
+
from dataforge.observability import configure_fastapi_observability
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger("dataforge.env.server")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _build_cors_origins() -> list[str]:
    """Read the CORS allowlist from ``DATAFORGE_OPENENV_ORIGINS``.

    The variable is a comma-separated list; surrounding whitespace is trimmed
    and empty entries are dropped. An unset variable yields an empty list.
    """
    configured = os.environ.get("DATAFORGE_OPENENV_ORIGINS", "")
    origins: list[str] = []
    for candidate in configured.split(","):
        cleaned = candidate.strip()
        if cleaned:
            origins.append(cleaned)
    return origins
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _build_cors_origin_regex() -> str | None:
    """Return the localhost origin pattern when ``DATAFORGE_OPENENV_DEV`` is ``"1"``.

    Any other value, or an unset variable, yields ``None``.
    """
    dev_mode = os.environ.get("DATAFORGE_OPENENV_DEV") == "1"
    return r"^http://(?:localhost|127(?:\.\d{1,3}){3})(?::\d+)?$" if dev_mode else None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# FastAPI application object; the route handlers below are registered on it.
app = FastAPI(
    title="DataForge Environment",
    description="OpenEnv-compatible RL environment for data-quality repair.",
    version="0.1.0",
)

# CORS: explicit allowlist from the environment, plus an optional localhost
# pattern when dev mode is enabled. Credentials are never allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_build_cors_origins(),
    allow_origin_regex=_build_cors_origin_regex(),
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
# HTTPException responses are rendered by the shared problem-details handler.
app.add_exception_handler(HTTPException, problem_exception_handler)
configure_fastapi_observability(app, service_name="dataforge-openenv")

# Session registry. _registry_lock guards _sessions and rebinding of
# _default_env; _default_env serves requests that carry no episode id.
_registry_lock = RLock()
_default_env = DataForgeEnv()
_sessions: dict[str, DataForgeEnv] = {}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _get_env(episode_id: str | None) -> DataForgeEnv:
    """Look up the environment for ``episode_id``.

    A missing or empty id selects the default environment. An id that is not
    registered raises ``HTTPException`` with status 404.
    """
    if episode_id:
        with _registry_lock:
            try:
                session_env = _sessions[episode_id]
            except KeyError as missing:
                raise HTTPException(
                    status_code=404,
                    detail={"error": "episode_not_found", "episode_id": episode_id},
                ) from missing
        return session_env
    return _default_env
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _remember_env(env: DataForgeEnv, episode_id: str) -> None:
    """Store ``env`` under ``episode_id`` and make it the default environment."""
    global _default_env
    with _registry_lock:
        # Both updates happen under the lock so lookups see a consistent pair.
        _default_env = env
        _sessions[episode_id] = env
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@app.post("/reset")
async def reset(seed: int | None = None) -> dict[str, Any]:
    """Create a fresh environment, reset it, and register it as a session.

    The session key is the ``episode_id`` reported in the reset result's info.
    """
    fresh_env = DataForgeEnv()
    outcome = fresh_env.reset(seed=seed)
    _remember_env(fresh_env, str(outcome.info["episode_id"]))
    return outcome.model_dump(mode="json")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@app.post("/step")
async def step(action: dict[str, Any]) -> dict[str, Any]:
    """Run one action against the environment named by its ``episode_id``.

    The ``episode_id`` key is removed before the remaining payload is passed
    to the environment; without one, the default environment is used.
    """
    requested_id = action.get("episode_id")
    forwarded = {key: value for key, value in action.items() if key != "episode_id"}
    target_id = str(requested_id) if requested_id else None
    outcome = _get_env(target_id).step(forwarded)
    return outcome.model_dump(mode="json")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@app.get("/state")
async def state(episode_id: str | None = None) -> dict[str, Any]:
    """Return the JSON-dumped state snapshot of the selected environment."""
    selected_env = _get_env(episode_id)
    snapshot = selected_env.state()
    return snapshot.model_dump(mode="json")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@app.post("/close")
async def close(request: Request, episode_id: str | None = None) -> dict[str, Any]:
    """Close an environment and remove its session from the registry.

    The episode id comes from the query parameter or, when that is absent,
    from an ``episode_id`` key in the JSON body. An unreadable body is treated
    as having no id. Without any id the default environment is closed.
    """
    target_episode_id = episode_id
    if episode_id is None:
        # The body is optional, so any failure to parse it means "no id".
        try:
            body = await request.json()
        except Exception:
            body = None
        if isinstance(body, dict):
            from_body = body.get("episode_id")
            if from_body:
                target_episode_id = str(from_body)
    # Normalise an empty query value to None, as the response reports it.
    target_episode_id = target_episode_id or None

    _get_env(target_episode_id).close()
    if target_episode_id:
        with _registry_lock:
            _sessions.pop(target_episode_id, None)
    return {"status": "closed", "episode_id": target_episode_id}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@app.get("/health")
async def health() -> dict[str, Any]:
    """Report that the service is up, along with the environment name."""
    return dict(status="healthy", environment="dataforge-env")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@app.get("/metadata")
async def metadata() -> dict[str, Any]:
    """Return the static name, version, description and action types."""
    supported_actions = [
        "INSPECT_ROWS",
        "SQL_QUERY",
        "STAT_TEST",
        "PATTERN_MATCH",
        "HYPOTHESIS",
        "ROOT_CAUSE",
        "DIAGNOSE",
        "FIX",
    ]
    summary = (
        "DataForge RL Environment — agents learn to detect, diagnose, "
        "and repair data-quality issues in tabular datasets."
    )
    return {
        "name": "dataforge-env",
        "version": "0.1.0",
        "description": summary,
        "action_types": supported_actions,
    }
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@app.get("/schema")
async def schema() -> dict[str, Any]:
    """Return the JSON schemas of the action, observation and state models."""
    # Action is wrapped in a TypeAdapter rather than queried for its schema directly.
    action_adapter: TypeAdapter[Action] = TypeAdapter(Action)
    schemas: dict[str, Any] = {}
    schemas["action"] = action_adapter.json_schema()
    schemas["observation"] = DataForgeObservation.model_json_schema()
    schemas["state"] = EnvState.model_json_schema()
    return schemas
|
dataforge/evaluation_contract.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public evaluation evidence models for DataForge repair releases."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import json
|
| 7 |
+
from typing import Any, Literal
|
| 8 |
+
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
# Closed set of labels accepted by ``EvaluationTaskV2.inferability`` and used
# as keys of the per-slice score mappings on ``ReleaseEvidenceV2``.
InferabilityLabel = Literal[
    "deterministic_normalization",
    "context_derivable",
    "external_reference_required",
    "not_inferable_from_prompt",
]
# Default value of ``ReleaseEvidenceV2.promotion_slice``.
PROMOTION_SLICE: InferabilityLabel = "deterministic_normalization"
# NOTE(review): the two groupings below are not referenced in this module;
# presumably consumed by the release gate -- confirm against callers.
ABSTENTION_SLICES = frozenset({"external_reference_required", "not_inferable_from_prompt"})
AUXILIARY_SLICES = frozenset({"context_derivable"})
# Allowed values of ``ReleaseEvidenceV2.promotion_status``.
PromotionStatus = Literal[
    "diagnostic_only",
    "diagnostic_promoted",
    "quality_improved_verified",
    "public_quality_milestone",
    "rejected",
]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class EvaluationTaskV2(BaseModel):
    """One auditable, source-stable model grading task.

    Ground truth is retained for local grading but excluded from normal JSON
    serialization so prompts and public reports cannot accidentally leak labels.
    The model is frozen.
    """

    schema_version: Literal["evaluation_task_v2"] = "evaluation_task_v2"
    task_id: str = Field(min_length=1)
    # Exactly 64 characters, the length of a SHA-256 hex digest
    # (presumably produced by prompt_sha256 -- confirm against callers).
    prompt_hash: str = Field(min_length=64, max_length=64)
    dataset_sha: str = Field(min_length=1)
    split_id: str = Field(min_length=1)
    inferability: InferabilityLabel
    prompt: dict[str, Any]
    # Both lists must contain at least one entry.
    allowed_columns: list[str] = Field(min_length=1)
    valid_rows: list[int] = Field(min_length=1)
    provenance: dict[str, Any]
    # exclude=True keeps this field out of serialized output.
    hidden_ground_truth: list[dict[str, Any]] = Field(default_factory=list, exclude=True)

    model_config = {"frozen": True}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ReleaseEvidenceV2(BaseModel):
    """Serializable release-gate evidence for model and benchmark promotion.

    Identifies the model and dataset revisions, records the headline scores
    (each constrained to the range 0.0-1.0), and carries the promotion
    decision together with any gate failures. The model is frozen.
    """

    schema_version: Literal["release_evidence_v2"] = "release_evidence_v2"
    model_repo: str = Field(min_length=1)
    model_sha: str = Field(min_length=1)
    dataset_repo: str = Field(min_length=1)
    dataset_sha: str = Field(min_length=1)
    # Rates and F1 scores are validated to lie in [0.0, 1.0].
    strict_macro_f1: float = Field(ge=0.0, le=1.0)
    canonicalized_macro_f1: float = Field(ge=0.0, le=1.0)
    parse_success_rate: float = Field(ge=0.0, le=1.0)
    schema_case_error_count: int = Field(ge=0)
    promotion_slice: InferabilityLabel = PROMOTION_SLICE
    # Per-slice mappings keyed by inferability label; both default to empty.
    slice_scores: dict[InferabilityLabel, dict[str, float | int]] = Field(default_factory=dict)
    inferability_slice_scores: dict[InferabilityLabel, float] = Field(default_factory=dict)
    package_versions: dict[str, str] = Field(default_factory=dict)
    promotion_status: PromotionStatus
    gate_failures: list[str] = Field(default_factory=list)

    model_config = {"frozen": True}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def prompt_sha256(prompt: dict[str, Any]) -> str:
    """Return the SHA-256 hex digest of ``prompt`` in canonical JSON form.

    Keys are sorted and separators are compact, so dictionaries that differ
    only in key order hash to the same value.
    """
    canonical = json.dumps(prompt, sort_keys=True, separators=(",", ":"))
    hasher = hashlib.sha256()
    hasher.update(canonical.encode("utf-8"))
    return hasher.hexdigest()
|
dataforge/fixtures/hospital_10rows.csv
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
provider_number,hospital_name,city,state,zip_code,phone_number,rating,mortality_rate,readmission_rate,er_wait_time
|
| 2 |
+
PRV001,General Hospital,Springfield,IL,62701,2175550101,4.2,0.023,0.145,28
|
| 3 |
+
PRV002,St. Mary Medical Center,Chicago,IL,60601,3125550202,3.8,0.031,0.162,35
|
| 4 |
+
PRV001,Springfield Medical,Springfield,IL,62701,2175550303,4.5,0.019,0.138,22
|
| 5 |
+
PRV003,Mercy Hospital,Peoria,IL,61602,3095550404,3.5,0.028,0.158,31
|
| 6 |
+
PRV004,Northwestern Memorial,Chicago,IL,60611,not available,4.1,0.025,0.149,26
|
| 7 |
+
PRV005,Rush University MC,Chicago,IL,60612,3125550606,45.0,0.022,0.141,29
|
| 8 |
+
PRV006,Advocate Christ,Oak Lawn,IL,60453,7085550707,3.9,0.027,0.155,33
|
| 9 |
+
PRV007,Loyola University MC,Maywood,IL,60153,7085550808,4.3,0.020,0.142,25
|
| 10 |
+
PRV008,Presence St. Joseph,Joliet,IL,60435,8155550909,4.0,0.026,0.151,30
|
| 11 |
+
PRV009,Edward Hospital,Naperville,IL,60540,6305551010,3.7,0.029,0.160,34
|
dataforge/fixtures/hospital_schema.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hospital dataset schema for DataForge profile command.
|
| 2 |
+
|
| 3 |
+
columns:
|
| 4 |
+
provider_number: str
|
| 5 |
+
hospital_name: str
|
| 6 |
+
city: str
|
| 7 |
+
state: str
|
| 8 |
+
zip_code: str
|
| 9 |
+
phone_number: str
|
| 10 |
+
rating: float
|
| 11 |
+
mortality_rate: float
|
| 12 |
+
readmission_rate: float
|
| 13 |
+
er_wait_time: int
|
| 14 |
+
|
| 15 |
+
functional_dependencies:
|
| 16 |
+
- determinant: [provider_number]
|
| 17 |
+
dependent: hospital_name
|
dataforge/http/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""HTTP helpers shared by DataForge backend surfaces."""
|
dataforge/http/problem.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RFC 9457 problem details helpers for FastAPI surfaces."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from collections.abc import Mapping
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from fastapi import HTTPException, Request
|
| 9 |
+
from fastapi.responses import JSONResponse
|
| 10 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ProblemDetail(BaseModel):
    """RFC 9457 problem detail response with extension members.

    The five declared fields are the standard members; any additional keyword
    passed at construction is kept as an extension member.
    """

    type: str = Field(default="about:blank")
    title: str
    status: int
    detail: str
    instance: str | None = None

    # strict=True disables type coercion; extra="allow" retains unknown keys
    # so extension members survive serialization.
    model_config = ConfigDict(strict=True, extra="allow")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def problem_body(
    *,
    status: int,
    title: str,
    detail: str,
    type_: str = "about:blank",
    instance: str | None = None,
    **extensions: Any,
) -> dict[str, Any]:
    """Return a problem-details dict built through ``ProblemDetail``.

    Args:
        status: HTTP status code.
        title: Short summary of the problem type.
        detail: Explanation specific to this occurrence.
        type_: URI identifying the problem type.
        instance: Optional URI reference for this occurrence.
        **extensions: Extra members copied into the body.

    Returns:
        The JSON-mode dump of the model with ``None`` members omitted.
    """
    model = ProblemDetail(
        status=status,
        title=title,
        detail=detail,
        type=type_,
        instance=instance,
        **extensions,
    )
    serialized = model.model_dump(mode="json", exclude_none=True)
    return serialized
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def problem_response(
    *,
    status: int,
    title: str,
    detail: str,
    type_: str = "about:blank",
    instance: str | None = None,
    headers: Mapping[str, str] | None = None,
    **extensions: Any,
) -> JSONResponse:
    """Wrap a problem details body in an ``application/problem+json`` response.

    ``status`` is used both as the HTTP status code and as the ``status``
    member of the body; ``headers`` are forwarded to the response unchanged.
    """
    payload = problem_body(
        status=status,
        title=title,
        detail=detail,
        type_=type_,
        instance=instance,
        **extensions,
    )
    return JSONResponse(
        content=payload,
        status_code=status,
        headers=headers,
        media_type="application/problem+json",
    )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Normalize FastAPI HTTPException values into problem details.

    When ``exc.detail`` is a dict, its ``error`` entry (default
    ``"http_error"``) becomes the error code, and the human-readable detail
    is taken from ``message``, then ``detail``, then the error code itself.
    The remaining string-keyed entries are forwarded as extension members.
    Any other ``exc.detail`` value is stringified and reported under the
    ``"http_error"`` code.

    The response carries ``exc.status_code``, ``exc.headers`` and the request
    path as ``instance``.
    """
    # These names are bound explicitly in the problem_response() call below
    # (or, for "type", in the ProblemDetail constructor it reaches). Passing
    # them again through **extensions would raise
    # "TypeError: got multiple values for keyword argument", turning e.g.
    # HTTPException(detail={"detail": "..."}) into an unhandled server error.
    reserved_members = frozenset(
        {"type", "type_", "title", "status", "detail", "instance", "headers"}
    )

    raw_detail = exc.detail
    extensions: dict[str, Any] = {}
    if isinstance(raw_detail, dict):
        error_code = str(raw_detail.get("error", "http_error"))
        message = str(raw_detail.get("message") or raw_detail.get("detail") or error_code)
        # Only string keys can be expanded as keyword arguments.
        extensions.update(
            (key, value)
            for key, value in raw_detail.items()
            if isinstance(key, str) and key not in reserved_members
        )
    else:
        error_code = "http_error"
        message = str(raw_detail)

    return problem_response(
        status=exc.status_code,
        type_=f"https://dataforge.local/problems/{error_code}",
        title=error_code.replace("_", " ").title(),
        detail=message,
        instance=str(request.url.path),
        headers=exc.headers,
        **extensions,
    )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
async def problem_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Exception-handler adapter that accepts any ``Exception``.

    ``HTTPException`` instances are delegated to ``http_exception_handler``;
    every other exception is re-raised unchanged.
    """
    if not isinstance(exc, HTTPException):
        raise exc
    return await http_exception_handler(request, exc)
|
dataforge/integrations/dbt.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""The dbt integration lives in the separate ``dataforge15-dbt`` package."""
|
dataforge/observability.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optional OpenTelemetry hooks for DataForge backend surfaces."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from collections.abc import Iterator
|
| 7 |
+
from contextlib import contextmanager, nullcontext
|
| 8 |
+
from importlib import import_module
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
_SENSITIVE_ATTR_FRAGMENTS = ("authorization", "cookie", "token", "key", "secret", "password")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _otel_enabled() -> bool:
    """Report whether ``DATAFORGE_OTEL_ENABLED`` holds a truthy flag.

    The value is compared after trimming whitespace and lower-casing;
    ``1``, ``true``, ``yes`` and ``on`` enable instrumentation. An unset
    variable counts as disabled.
    """
    truthy_flags = ("1", "true", "yes", "on")
    raw_flag = os.environ.get("DATAFORGE_OTEL_ENABLED", "")
    return raw_flag.strip().lower() in truthy_flags
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _safe_attrs(attributes: dict[str, Any]) -> dict[str, str | int | float | bool]:
    """Filter telemetry attributes down to scalar, non-sensitive entries.

    An entry is dropped when its lower-cased key contains any fragment from
    ``_SENSITIVE_ATTR_FRAGMENTS``, when the lower-cased key is one of the
    bulk-data names listed below, or when its value is not a ``str``,
    ``int``, ``float`` or ``bool``. Surviving entries keep their original key.
    """
    bulk_data_keys = {"row_values", "rows", "payload", "source_bytes", "csv"}

    def _is_exportable(name: str, value: Any) -> bool:
        folded = name.lower()
        looks_sensitive = any(
            fragment in folded for fragment in _SENSITIVE_ATTR_FRAGMENTS
        )
        if looks_sensitive or folded in bulk_data_keys:
            return False
        return isinstance(value, (str, int, float, bool))

    return {
        name: value
        for name, value in attributes.items()
        if _is_exportable(name, value)
    }
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def configure_fastapi_observability(app: Any, *, service_name: str) -> bool:
    """Attach OpenTelemetry instrumentation to ``app`` when it is enabled.

    Returns ``False`` without touching ``app`` when the feature flag is off
    or the OpenTelemetry modules cannot be imported. Otherwise the service
    name is stored on ``app.state``, the app is instrumented with the current
    tracer provider (documentation URLs excluded), and ``True`` is returned.
    """
    if not _otel_enabled():
        return False

    try:
        instrumentation = import_module("opentelemetry.instrumentation.fastapi")
        otel_trace = import_module("opentelemetry.trace")
    except ImportError:
        # The OpenTelemetry packages are optional; absence just disables tracing.
        return False

    app.state.dataforge_service_name = service_name
    instrumentor = instrumentation.FastAPIInstrumentor
    instrumentor.instrument_app(
        app,
        tracer_provider=otel_trace.get_tracer_provider(),
        excluded_urls="/api/docs,/docs,/redoc,/openapi.json",
    )
    return True
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@contextmanager
def repair_stage_span(stage: str, **attributes: Any) -> Iterator[None]:
    """Run the managed block inside a span named ``stage`` when tracing is on.

    When the feature flag is off or ``opentelemetry.trace`` cannot be
    imported, the block runs without any span. Otherwise a span is started
    from the ``dataforge.repair`` tracer and the attributes that pass
    ``_safe_attrs`` are set on it before the block runs.
    """
    otel_trace = None
    if _otel_enabled():
        try:
            otel_trace = import_module("opentelemetry.trace")
        except ImportError:
            otel_trace = None

    if otel_trace is None:
        # Tracing disabled or unavailable: behave as a no-op context manager.
        with nullcontext():
            yield
        return

    repair_tracer = otel_trace.get_tracer("dataforge.repair")
    with repair_tracer.start_as_current_span(stage) as active_span:
        for name, value in _safe_attrs(attributes).items():
            active_span.set_attribute(name, value)
        yield
|
dataforge/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
dataforge/release/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Release verification helpers for DataForge."""
|
| 2 |
+
|