Praneshrajan15 commited on
Commit
eed1cab
·
verified ·
1 Parent(s): fe6681f

Deploy DataForge playground API

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +63 -28
  2. README.md +21 -18
  3. README_MAIN.md +0 -96
  4. dataforge/__init__.py +115 -2
  5. dataforge/agent/__init__.py +16 -1
  6. dataforge/agent/providers.py +11 -3
  7. dataforge/agent/scratchpad.py +183 -0
  8. dataforge/agent/tool_actions.py +343 -0
  9. dataforge/bench/core.py +6 -0
  10. dataforge/bench/groq_client.py +306 -27
  11. dataforge/bench/methods.py +35 -0
  12. dataforge/bench/report.py +19 -13
  13. dataforge/bench/runner.py +45 -6
  14. dataforge/causal/__init__.py +21 -1
  15. dataforge/causal/dag.py +174 -0
  16. dataforge/causal/pc.py +232 -0
  17. dataforge/causal/root_cause.py +193 -0
  18. dataforge/cli/__init__.py +10 -4
  19. dataforge/cli/audit.py +70 -0
  20. dataforge/cli/bench.py +23 -4
  21. dataforge/cli/common.py +26 -4
  22. dataforge/cli/profile.py +61 -16
  23. dataforge/cli/release.py +39 -0
  24. dataforge/cli/repair.py +104 -249
  25. dataforge/cli/watch.py +142 -0
  26. dataforge/datasets/embedded/hospital/clean.csv +11 -0
  27. dataforge/datasets/embedded/hospital/dirty.csv +11 -0
  28. dataforge/datasets/real_world.py +37 -7
  29. dataforge/detectors/__init__.py +2 -4
  30. dataforge/detectors/base.py +5 -5
  31. dataforge/detectors/decimal_shift.py +11 -17
  32. dataforge/detectors/fd_violation.py +21 -24
  33. dataforge/detectors/type_mismatch.py +6 -13
  34. dataforge/engine/__init__.py +33 -1
  35. dataforge/engine/repair.py +670 -0
  36. dataforge/env/__init__.py +22 -1
  37. dataforge/env/environment.py +884 -0
  38. dataforge/env/observation.py +61 -0
  39. dataforge/env/openenv_core.py +146 -0
  40. dataforge/env/reward.py +128 -0
  41. dataforge/env/server.py +175 -0
  42. dataforge/evaluation_contract.py +76 -0
  43. dataforge/fixtures/hospital_10rows.csv +11 -0
  44. dataforge/fixtures/hospital_schema.yaml +17 -0
  45. dataforge/http/__init__.py +1 -0
  46. dataforge/http/problem.py +99 -0
  47. dataforge/integrations/dbt.py +1 -0
  48. dataforge/observability.py +76 -0
  49. dataforge/py.typed +1 -0
  50. dataforge/release/__init__.py +2 -0
Dockerfile CHANGED
@@ -1,28 +1,63 @@
1
- # DataForge Playground - Multi-stage Docker build for HF Spaces.
2
- FROM python:3.12-slim AS builder
3
- WORKDIR /build
4
- RUN apt-get update && \
5
- apt-get install -y --no-install-recommends gcc g++ && \
6
- rm -rf /var/lib/apt/lists/*
7
- COPY playground/api/requirements.txt /build/requirements.txt
8
- RUN pip install --no-cache-dir -r /build/requirements.txt
9
- COPY pyproject.toml /build/dataforge_src/pyproject.toml
10
- COPY README_MAIN.md /build/dataforge_src/README.md
11
- COPY dataforge/ /build/dataforge_src/dataforge/
12
- COPY constitutions/ /build/dataforge_src/constitutions/
13
- RUN pip install --no-cache-dir /build/dataforge_src
14
-
15
- FROM python:3.12-slim
16
- RUN useradd -m -u 1000 user
17
- COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
18
- COPY --from=builder /usr/local/bin /usr/local/bin
19
- COPY --from=builder /build/dataforge_src/constitutions /usr/local/lib/python3.12/site-packages/constitutions
20
- COPY playground/api/app.py /home/user/app/app.py
21
- COPY playground/api/samples/ /home/user/app/samples/
22
- COPY playground/web/ /home/user/app/web/
23
- USER user
24
- WORKDIR /home/user/app
25
- EXPOSE 7860
26
- ENV PORT=7860
27
- ENV DATAFORGE_PLAYGROUND_DEV=0
28
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1", "--timeout-keep-alive", "5"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DataForge Playground Multi-stage Docker build for HF Spaces.
2
+ #
3
+ # Target: <= 600 MB image. Runs as non-root UID 1000 (HF requirement).
4
+ # Single-worker uvicorn with --timeout-keep-alive 5 (slowloris mitigation).
5
+ #
6
+ # See specs/SPEC_playground.md §4 and §6.5.
7
+
8
+ # ============================================================
9
+ # Stage 1: builder — install all Python dependencies
10
+ # ============================================================
11
+ FROM python:3.12-slim AS builder
12
+
13
+ WORKDIR /build
14
+
15
+ # System deps for building wheels
16
+ RUN apt-get update && \
17
+ apt-get install -y --no-install-recommends gcc g++ && \
18
+ rm -rf /var/lib/apt/lists/*
19
+
20
+ # Install playground API requirements
21
+ COPY playground/api/requirements.txt /build/requirements.txt
22
+ RUN pip install --no-cache-dir -r /build/requirements.txt
23
+
24
+ # Copy dataforge source and install it
25
+ COPY pyproject.toml /build/dataforge_src/pyproject.toml
26
+ COPY README.md /build/dataforge_src/README.md
27
+ COPY dataforge/ /build/dataforge_src/dataforge/
28
+ COPY constitutions/ /build/dataforge_src/constitutions/
29
+ RUN pip install --no-cache-dir /build/dataforge_src
30
+
31
+ # ============================================================
32
+ # Stage 2: runtime — minimal image with only installed packages
33
+ # ============================================================
34
+ FROM python:3.12-slim
35
+
36
+ # HF Spaces requires non-root user with UID 1000
37
+ RUN useradd -m -u 1000 user
38
+
39
+ # Copy installed Python packages from builder
40
+ COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
41
+ COPY --from=builder /usr/local/bin /usr/local/bin
42
+
43
+ # Copy constitutions to the site-packages-relative path SafetyFilter expects.
44
+ COPY --from=builder /build/dataforge_src/constitutions /usr/local/lib/python3.12/site-packages/constitutions
45
+
46
+ # Copy application code
47
+ COPY playground/api/app.py /home/user/app/app.py
48
+ COPY playground/api/samples/ /home/user/app/samples/
49
+
50
+ # Switch to non-root user
51
+ USER user
52
+ WORKDIR /home/user/app
53
+
54
+ # Expose the port HF Spaces expects
55
+ EXPOSE 7860
56
+
57
+ # Environment
58
+ ENV PORT=7860
59
+ ENV DATAFORGE_PLAYGROUND_DEV=0
60
+
61
+ # Start uvicorn with single worker (slowapi in-memory limiter contract)
62
+ # and honor PORT for Hugging Face runtime assignment.
63
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860} --workers 1 --timeout-keep-alive 5"]
README.md CHANGED
@@ -7,37 +7,40 @@ sdk: docker
7
  app_port: 7860
8
  pinned: false
9
  license: apache-2.0
10
- short_description: Upload a CSV, profile and dry-run-repair it in your browser.
11
  ---
12
 
13
- # DataForge Playground
14
 
15
- Upload a CSV file and instantly profile it for data-quality issues or
16
- preview proposed repairs all in your browser, no installation required.
 
17
 
18
- **What it does:**
19
 
20
- - **Profile**: Detects type mismatches, decimal shifts, and functional
21
- dependency violations using heuristic detectors.
22
- - **Repair (Dry Run)**: Proposes fixes through the full Safety → Verifier →
23
- Transaction pipeline, returning an ephemeral transaction journal.
 
24
 
25
- **What it does NOT do:**
26
 
27
- - No data is persisted. Your file is processed in memory and discarded.
28
- - No cookies, no analytics of file contents.
29
- - No LLM calls by default (opt-in only, requires a configured key).
 
30
 
31
- ## Run locally instead
32
 
33
  ```bash
34
- pip install dataforge
35
- dataforge profile your_data.csv
36
- dataforge repair your_data.csv --dry-run
37
  ```
38
 
39
  ## Source
40
 
41
- - Main repository: [github.com/Praneshrajan15/data-quality-env](https://github.com/Praneshrajan15/data-quality-env)
42
  - Spec: `specs/SPEC_playground.md`
43
  - License: Apache-2.0
 
7
  app_port: 7860
8
  pinned: false
9
  license: apache-2.0
10
+ short_description: Profile CSVs and dry-run safe repairs.
11
  ---
12
 
13
+ # DataForge Playground API
14
 
15
+ This is the API backend for the DataForge playground. The browser UI is deployed
16
+ separately through Cloudflare Workers Static Assets; this Hugging Face Docker
17
+ Space serves stateless CSV profiling and dry-run repair endpoints.
18
 
19
+ ## What It Does
20
 
21
+ - Profile: detects type mismatches, decimal shifts, and functional dependency
22
+ violations.
23
+ - Repair dry run: proposes fixes through SafetyFilter -> SMTVerifier and
24
+ returns an ephemeral transaction receipt without persisting user data.
25
+ - Samples: serves small deterministic CSV examples for the static frontend.
26
 
27
+ ## What It Does Not Do
28
 
29
+ - It does not persist uploaded files.
30
+ - It does not use cookies or analytics for file contents.
31
+ - It does not call an LLM by default.
32
+ - It does not perform autonomous production repair.
33
 
34
+ ## Run Locally
35
 
36
  ```bash
37
+ python -m pip install -e ".[dev]"
38
+ pip install -r playground/api/requirements.txt
39
+ uvicorn playground.api.app:app --reload --port 7860
40
  ```
41
 
42
  ## Source
43
 
44
+ - Main repository: `github.com/Praneshrajan15/data-quality-env`
45
  - Spec: `specs/SPEC_playground.md`
46
  - License: Apache-2.0
README_MAIN.md DELETED
@@ -1,96 +0,0 @@
1
- # DataForge
2
-
3
- DataForge currently ships a real Week 3 CLI for CSV profiling and repair.
4
-
5
- This repository now includes shipped detectors, deterministic repairers,
6
- constitutional safety gating, SMT-backed structural verification, reversible
7
- transaction logs, and real-world benchmark infrastructure. The hosted
8
- playground, warehouse integrations, and trained model family remain future
9
- work.
10
-
11
- ## Current Status
12
-
13
- - `dataforge profile`, `dataforge repair`, `dataforge revert`, and `dataforge bench`
14
- - Three shipped detectors: `type_mismatch`, `decimal_shift`, `fd_violation`
15
- - Three shipped repairers with safety + verifier gating in the apply path
16
- - Reversible transaction logs with byte-identical revert via source snapshots
17
- - Benchmark/report generation infrastructure for Hospital / Flights / Beers
18
- - `Makefile` targets for setup, lint, type-checking, and tests
19
- - CI plus unit / integration / property / adversarial coverage
20
-
21
- ## Benchmark Results
22
-
23
- <!-- BENCH:START -->
24
- Generated from `eval/results/agent_comparison.json`.
25
-
26
- | Method | Precision | Recall | F1 | Avg Steps | Quota Units |
27
- | --- | --- | --- | --- | --- | --- |
28
- | heuristic | 0.0000 | 0.0000 | 0.0000 | 134.33 | 0.0000 |
29
- | llm_react | Skipped | Skipped | Skipped | Skipped | Skipped |
30
- | llm_zeroshot | Skipped | Skipped | Skipped | Skipped | Skipped |
31
- | random | 0.0038 | 0.0003 | 0.0005 | 150.33 | 0.0000 |
32
-
33
- See `BENCHMARK_REPORT.md` for per-dataset tables, error bars, and citation-only SOTA rows.
34
-
35
- Skipped methods in this run: DATAFORGE_LLM_PROVIDER must be set to groq.
36
- <!-- BENCH:END -->
37
-
38
- ## Local Setup
39
-
40
- ```bash
41
- make setup
42
- make lint
43
- make type
44
- make test
45
- ```
46
-
47
- Verification works on Linux, macOS, or Windows (with Git Bash as the
48
- shell substrate for GNU Make). Requires Python 3.11 or 3.12
49
- (`requires-python = ">=3.11,<3.13"`).
50
-
51
- ### Windows-specific setup
52
-
53
- ```powershell
54
- # Install Python 3.12 and GNU Make if not present
55
- winget install -e --id Python.Python.3.12
56
- winget install -e --id ezwinports.make
57
-
58
- # Create and activate a project venv
59
- py -3.12 -m venv .venv
60
- .\.venv\Scripts\Activate.ps1
61
-
62
- # Install dependencies and verify
63
- python -m pip install -e ".[all]"
64
- make lint && make type && make test
65
- ```
66
-
67
- Git for Windows provides the Bash implementation the Makefile uses on Windows.
68
- Do not rely on `C:\Windows\System32\bash.exe` (WSL).
69
-
70
- ## Environment Variables
71
-
72
- Future provider keys belong in a root `.env` file that is gitignored and meant
73
- to be loaded with `python-dotenv`.
74
-
75
- - `GROQ_API_KEY`
76
- - `GEMINI_API_KEY`
77
- - `CEREBRAS_API_KEY`
78
- - `OPENROUTER_API_KEY`
79
- - `HF_TOKEN`
80
-
81
- ## Repository Docs
82
-
83
- - [.cursor/rules/dataforge.md](.cursor/rules/dataforge.md) — always-applied rules
84
- - [ARCHITECTURE.md](ARCHITECTURE.md) — system diagram and dependency justification
85
- - [DECISIONS.md](DECISIONS.md) — technical decision log
86
- - [CONTRIBUTING.md](CONTRIBUTING.md) — workflow and code standards
87
- - [CLAUDE.md](CLAUDE.md) — living knowledge base for Cursor sessions
88
- - [CURSOR_MASTER.md](CURSOR_MASTER.md) — full context and prompt pack
89
- - [META_CONTEXT.md](META_CONTEXT.md) — meta-context (read before writing code)
90
- - [FILE_STRUCTURE.md](FILE_STRUCTURE.md) — canonical target directory tree
91
- - [SECURITY.md](SECURITY.md) — vulnerability reporting policy
92
- - [specs/SPEC_TEMPLATE.md](specs/SPEC_TEMPLATE.md) — spec template for new modules
93
-
94
- ## License
95
-
96
- Apache-2.0. See [LICENSE](LICENSE).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataforge/__init__.py CHANGED
@@ -1,5 +1,118 @@
1
- """DataForge public package."""
2
 
3
- __all__ = ["__version__"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  __version__ = "0.1.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataForge public package.
2
 
3
+ The root package is the stable facade for integration surfaces. Symbols are
4
+ resolved lazily so importing :mod:`dataforge` does not eagerly import pandas,
5
+ FastAPI-facing helpers, or the SMT stack.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from importlib import import_module
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ if TYPE_CHECKING:
14
+ from dataforge.cli.common import load_schema, read_csv, schema_from_mapping
15
+ from dataforge.detectors import Issue, Schema, Severity, run_all_detectors
16
+ from dataforge.engine.repair import (
17
+ CandidateFix,
18
+ RepairFailure,
19
+ RepairPipelineRequest,
20
+ RepairPipelineResult,
21
+ RepairReceipt,
22
+ VerifiedFix,
23
+ run_repair_pipeline,
24
+ )
25
+ from dataforge.repair_contract import CONTRACT_VERSION
26
+ from dataforge.repairers import ProposedFix
27
+ from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict
28
+ from dataforge.transactions.log import (
29
+ TransactionAuditReport,
30
+ TransactionAuditVerdict,
31
+ TransactionLogError,
32
+ verify_transaction_log,
33
+ )
34
+ from dataforge.transactions.revert import TransactionRevertError, revert_transaction
35
+ from dataforge.transactions.txn import CellFix, RepairTransaction
36
+ from dataforge.verifier import SMTVerifier, VerificationResult, VerificationVerdict
37
+
38
+ __all__ = [
39
+ "CONTRACT_VERSION",
40
+ "CandidateFix",
41
+ "CellFix",
42
+ "Issue",
43
+ "ProposedFix",
44
+ "RepairFailure",
45
+ "RepairPipelineRequest",
46
+ "RepairPipelineResult",
47
+ "RepairReceipt",
48
+ "RepairTransaction",
49
+ "SMTVerifier",
50
+ "SafetyContext",
51
+ "SafetyFilter",
52
+ "SafetyResult",
53
+ "SafetyVerdict",
54
+ "Schema",
55
+ "Severity",
56
+ "TransactionAuditReport",
57
+ "TransactionAuditVerdict",
58
+ "TransactionLogError",
59
+ "TransactionRevertError",
60
+ "VerificationResult",
61
+ "VerificationVerdict",
62
+ "VerifiedFix",
63
+ "__version__",
64
+ "load_schema",
65
+ "read_csv",
66
+ "revert_transaction",
67
+ "run_all_detectors",
68
+ "run_repair_pipeline",
69
+ "schema_from_mapping",
70
+ "verify_transaction_log",
71
+ ]
72
 
73
  __version__ = "0.1.0"
74
+
75
+ _PUBLIC_EXPORTS: dict[str, tuple[str, str]] = {
76
+ "CONTRACT_VERSION": ("dataforge.repair_contract", "CONTRACT_VERSION"),
77
+ "CandidateFix": ("dataforge.engine.repair", "CandidateFix"),
78
+ "CellFix": ("dataforge.transactions.txn", "CellFix"),
79
+ "Issue": ("dataforge.detectors", "Issue"),
80
+ "ProposedFix": ("dataforge.repairers", "ProposedFix"),
81
+ "RepairFailure": ("dataforge.engine.repair", "RepairFailure"),
82
+ "RepairPipelineRequest": ("dataforge.engine.repair", "RepairPipelineRequest"),
83
+ "RepairPipelineResult": ("dataforge.engine.repair", "RepairPipelineResult"),
84
+ "RepairReceipt": ("dataforge.engine.repair", "RepairReceipt"),
85
+ "RepairTransaction": ("dataforge.transactions.txn", "RepairTransaction"),
86
+ "SMTVerifier": ("dataforge.verifier", "SMTVerifier"),
87
+ "SafetyContext": ("dataforge.safety", "SafetyContext"),
88
+ "SafetyFilter": ("dataforge.safety", "SafetyFilter"),
89
+ "SafetyResult": ("dataforge.safety", "SafetyResult"),
90
+ "SafetyVerdict": ("dataforge.safety", "SafetyVerdict"),
91
+ "Schema": ("dataforge.detectors", "Schema"),
92
+ "Severity": ("dataforge.detectors", "Severity"),
93
+ "TransactionAuditReport": ("dataforge.transactions.log", "TransactionAuditReport"),
94
+ "TransactionAuditVerdict": ("dataforge.transactions.log", "TransactionAuditVerdict"),
95
+ "TransactionLogError": ("dataforge.transactions.log", "TransactionLogError"),
96
+ "TransactionRevertError": ("dataforge.transactions.revert", "TransactionRevertError"),
97
+ "VerificationResult": ("dataforge.verifier", "VerificationResult"),
98
+ "VerificationVerdict": ("dataforge.verifier", "VerificationVerdict"),
99
+ "VerifiedFix": ("dataforge.engine.repair", "VerifiedFix"),
100
+ "load_schema": ("dataforge.cli.common", "load_schema"),
101
+ "read_csv": ("dataforge.cli.common", "read_csv"),
102
+ "revert_transaction": ("dataforge.transactions.revert", "revert_transaction"),
103
+ "run_all_detectors": ("dataforge.detectors", "run_all_detectors"),
104
+ "run_repair_pipeline": ("dataforge.engine.repair", "run_repair_pipeline"),
105
+ "schema_from_mapping": ("dataforge.cli.common", "schema_from_mapping"),
106
+ "verify_transaction_log": ("dataforge.transactions.log", "verify_transaction_log"),
107
+ }
108
+
109
+
110
+ def __getattr__(name: str) -> Any:
111
+ """Resolve public facade exports on first use."""
112
+ try:
113
+ module_name, attribute_name = _PUBLIC_EXPORTS[name]
114
+ except KeyError as exc:
115
+ raise AttributeError(name) from exc
116
+ value = getattr(import_module(module_name), attribute_name)
117
+ globals()[name] = value
118
+ return value
dataforge/agent/__init__.py CHANGED
@@ -1 +1,16 @@
1
- """Agent package scaffolding for DataForge."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataForge agent package typed tool-use actions and scratchpad.
2
+
3
+ Public API:
4
+ parse_action — Parse raw dict into typed Action model.
5
+ Action — Discriminated union of all action types.
6
+ Scratchpad — In-episode hypothesis tracker.
7
+ """
8
+
9
+ from dataforge.agent.scratchpad import Scratchpad
10
+ from dataforge.agent.tool_actions import Action, parse_action
11
+
12
+ __all__ = [
13
+ "Action",
14
+ "Scratchpad",
15
+ "parse_action",
16
+ ]
dataforge/agent/providers.py CHANGED
@@ -59,8 +59,9 @@ def get_provider_name() -> str:
59
  """Read the active provider from the environment.
60
 
61
  Returns:
62
- The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``,
63
- defaulting to ``"groq"`` if not set.
 
64
 
65
  Example:
66
  >>> import os
@@ -68,7 +69,14 @@ def get_provider_name() -> str:
68
  >>> get_provider_name()
69
  'gemini'
70
  """
71
- return os.environ.get("DATAFORGE_LLM_PROVIDER", "groq").lower()
 
 
 
 
 
 
 
72
 
73
 
74
  async def complete(
 
59
  """Read the active provider from the environment.
60
 
61
  Returns:
62
+ The lowercased provider name from ``DATAFORGE_LLM_PROVIDER``.
63
+ When no explicit provider is configured, prefer a provider whose
64
+ credential is present in the environment.
65
 
66
  Example:
67
  >>> import os
 
69
  >>> get_provider_name()
70
  'gemini'
71
  """
72
+ configured = os.environ.get("DATAFORGE_LLM_PROVIDER")
73
+ if configured:
74
+ return configured.lower()
75
+ if os.environ.get("GROQ_API_KEY"):
76
+ return "groq"
77
+ if os.environ.get("GEMINI_API_KEY"):
78
+ return "gemini"
79
+ return "groq"
80
 
81
 
82
  async def complete(
dataforge/agent/scratchpad.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """In-episode hypothesis and issue tracker for the DataForge RL agent.
2
+
3
+ The scratchpad is a mutable, episode-scoped data structure that the agent
4
+ uses to record hypotheses, confirmed issues, and dead ends. The environment
5
+ exposes a compact summary of the scratchpad in each observation, enabling
6
+ the agent to reason about its investigation history without direct access
7
+ to the underlying data structure.
8
+
9
+ Example::
10
+
11
+ >>> from dataforge.agent.scratchpad import Scratchpad
12
+ >>> pad = Scratchpad()
13
+ >>> pad.add_hypothesis("Rating column has decimal shift", [5], ["rating"], "decimal_shift")
14
+ >>> pad.confirm_issue(5, "rating", "decimal_shift")
15
+ >>> pad.summary()
16
+ 'Hypotheses: 1 (0 pending). Confirmed: 1. Dead ends: 0.'
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from dataclasses import dataclass, field
22
+
23
+ __all__ = [
24
+ "ConfirmedIssue",
25
+ "DeadEnd",
26
+ "HypothesisRecord",
27
+ "Scratchpad",
28
+ ]
29
+
30
+
31
+ @dataclass(frozen=True)
32
+ class HypothesisRecord:
33
+ """A recorded hypothesis about a data-quality root cause.
34
+
35
+ Args:
36
+ claim: Textual description of the hypothesis.
37
+ affected_rows: Row indices the hypothesis covers.
38
+ affected_columns: Column names the hypothesis covers.
39
+ root_cause_type: Detector-vocabulary root cause type.
40
+ confirmed: Whether the hypothesis was confirmed by ground truth.
41
+ """
42
+
43
+ claim: str
44
+ affected_rows: tuple[int, ...]
45
+ affected_columns: tuple[str, ...]
46
+ root_cause_type: str
47
+ confirmed: bool = False
48
+
49
+
50
+ @dataclass(frozen=True)
51
+ class ConfirmedIssue:
52
+ """A confirmed data-quality issue at a specific location.
53
+
54
+ Args:
55
+ row: Zero-indexed row number.
56
+ column: Column name.
57
+ issue_type: Issue type classification.
58
+ """
59
+
60
+ row: int
61
+ column: str
62
+ issue_type: str
63
+
64
+
65
+ @dataclass(frozen=True)
66
+ class DeadEnd:
67
+ """A recorded dead end — an investigation path that yielded nothing.
68
+
69
+ Args:
70
+ description: What was tried and why it failed.
71
+ step_number: Step at which the dead end was recorded.
72
+ """
73
+
74
+ description: str
75
+ step_number: int
76
+
77
+
78
+ @dataclass
79
+ class Scratchpad:
80
+ """Mutable in-episode tracker for hypotheses, confirmed issues, and dead ends.
81
+
82
+ Reset at the start of each episode. The ``summary()`` method produces a
83
+ compact string for inclusion in agent observations.
84
+
85
+ Example::
86
+
87
+ >>> pad = Scratchpad()
88
+ >>> pad.add_hypothesis("Decimal shift in rating", [5], ["rating"], "decimal_shift")
89
+ >>> len(pad.hypotheses)
90
+ 1
91
+ """
92
+
93
+ hypotheses: list[HypothesisRecord] = field(default_factory=list)
94
+ confirmed_issues: list[ConfirmedIssue] = field(default_factory=list)
95
+ dead_ends: list[DeadEnd] = field(default_factory=list)
96
+
97
+ def add_hypothesis(
98
+ self,
99
+ claim: str,
100
+ affected_rows: list[int],
101
+ affected_columns: list[str],
102
+ root_cause_type: str,
103
+ ) -> HypothesisRecord:
104
+ """Record a new hypothesis.
105
+
106
+ Args:
107
+ claim: Textual description of the hypothesis.
108
+ affected_rows: Row indices the hypothesis covers.
109
+ affected_columns: Column names the hypothesis covers.
110
+ root_cause_type: Detector-vocabulary root cause type.
111
+
112
+ Returns:
113
+ The recorded hypothesis.
114
+ """
115
+ record = HypothesisRecord(
116
+ claim=claim,
117
+ affected_rows=tuple(affected_rows),
118
+ affected_columns=tuple(affected_columns),
119
+ root_cause_type=root_cause_type,
120
+ )
121
+ self.hypotheses.append(record)
122
+ return record
123
+
124
+ def confirm_hypothesis(self, index: int) -> None:
125
+ """Mark a hypothesis as confirmed.
126
+
127
+ Args:
128
+ index: Index into the ``hypotheses`` list.
129
+
130
+ Raises:
131
+ IndexError: If the index is out of range.
132
+ """
133
+ old = self.hypotheses[index]
134
+ self.hypotheses[index] = HypothesisRecord(
135
+ claim=old.claim,
136
+ affected_rows=old.affected_rows,
137
+ affected_columns=old.affected_columns,
138
+ root_cause_type=old.root_cause_type,
139
+ confirmed=True,
140
+ )
141
+
142
+ def confirm_issue(self, row: int, column: str, issue_type: str) -> None:
143
+ """Record a confirmed issue.
144
+
145
+ Args:
146
+ row: Zero-indexed row number.
147
+ column: Column name.
148
+ issue_type: Issue type classification.
149
+ """
150
+ self.confirmed_issues.append(ConfirmedIssue(row=row, column=column, issue_type=issue_type))
151
+
152
+ def add_dead_end(self, description: str, step_number: int) -> None:
153
+ """Record a dead end.
154
+
155
+ Args:
156
+ description: What was tried and why it failed.
157
+ step_number: Step at which the dead end was recorded.
158
+ """
159
+ self.dead_ends.append(DeadEnd(description=description, step_number=step_number))
160
+
161
+ def reset(self) -> None:
162
+ """Clear all tracked state for a new episode."""
163
+ self.hypotheses.clear()
164
+ self.confirmed_issues.clear()
165
+ self.dead_ends.clear()
166
+
167
+ def summary(self) -> str:
168
+ """Produce a compact summary string for observation embedding.
169
+
170
+ Returns:
171
+ A one-line summary of scratchpad state.
172
+
173
+ Example::
174
+
175
+ >>> Scratchpad().summary()
176
+ 'Hypotheses: 0 (0 pending). Confirmed: 0. Dead ends: 0.'
177
+ """
178
+ pending = sum(1 for h in self.hypotheses if not h.confirmed)
179
+ return (
180
+ f"Hypotheses: {len(self.hypotheses)} ({pending} pending). "
181
+ f"Confirmed: {len(self.confirmed_issues)}. "
182
+ f"Dead ends: {len(self.dead_ends)}."
183
+ )
dataforge/agent/tool_actions.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Typed tool-use action models for the DataForge RL environment.
2
+
3
+ This module defines a discriminated union of 8 action types that an RL agent
4
+ can submit to the DataForge environment. Each action is a standalone Pydantic
5
+ model with its own validation rules, preventing cross-model field pollution.
6
+
7
+ The ``parse_action`` function is the single entry point for HTTP handlers
8
+ and tests to validate raw action dicts into typed models.
9
+
10
+ Action Types:
11
+ INSPECT_ROWS — View a slice of the dataset.
12
+ SQL_QUERY — Execute read-only SQL against the episode DataFrame.
13
+ STAT_TEST — Run a statistical test on a column.
14
+ PATTERN_MATCH — Evaluate a regex pattern against column values.
15
+ HYPOTHESIS — Record a causal-root claim for credit.
16
+ ROOT_CAUSE — Analyze selected detected errors for minimal roots.
17
+ DIAGNOSE — Flag a suspected issue at (row, column).
18
+ FIX — Propose a corrected value for a diagnosed issue.
19
+
20
+ Example::
21
+
22
+ >>> from dataforge.agent.tool_actions import parse_action
23
+ >>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0, 1]})
24
+ >>> action.action_type
25
+ 'INSPECT_ROWS'
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ from typing import Annotated, Any, Literal
31
+
32
+ from pydantic import BaseModel, Field, field_validator
33
+
34
+ __all__ = [
35
+ "Action",
36
+ "Diagnose",
37
+ "Fix",
38
+ "Hypothesis",
39
+ "InspectRows",
40
+ "PatternMatch",
41
+ "RootCause",
42
+ "SqlQuery",
43
+ "StatTest",
44
+ "parse_action",
45
+ ]
46
+
47
+
48
+ class InspectRows(BaseModel):
49
+ """View a slice of dataset rows.
50
+
51
+ Args:
52
+ action_type: Must be ``"INSPECT_ROWS"``.
53
+ row_indices: Zero-indexed row indices to retrieve. At least 1 required.
54
+ column_names: Optional column filter. If omitted, all columns returned.
55
+
56
+ Example::
57
+
58
+ >>> InspectRows(action_type="INSPECT_ROWS", row_indices=[0, 1, 2])
59
+ """
60
+
61
+ action_type: Literal["INSPECT_ROWS"]
62
+ row_indices: list[int] = Field(min_length=1, description="Row indices to inspect (0-indexed).")
63
+ column_names: list[str] | None = Field(default=None, description="Optional column filter.")
64
+
65
+ @field_validator("row_indices")
66
+ @classmethod
67
+ def _validate_row_indices(cls, v: list[int]) -> list[int]:
68
+ """Validate that all row indices are non-negative."""
69
+ if any(i < 0 for i in v):
70
+ raise ValueError("All row indices must be >= 0")
71
+ return v
72
+
73
+ model_config = {"frozen": True}
74
+
75
+
76
+ class SqlQuery(BaseModel):
77
+ """Execute read-only SQL against the episode DataFrame via DuckDB.
78
+
79
+ Args:
80
+ action_type: Must be ``"SQL_QUERY"``.
81
+ query: SQL query string. Must be read-only (SELECT only).
82
+
83
+ Example::
84
+
85
+ >>> SqlQuery(action_type="SQL_QUERY", query="SELECT * FROM data LIMIT 5")
86
+ """
87
+
88
+ action_type: Literal["SQL_QUERY"]
89
+ query: str = Field(min_length=1, description="Read-only SQL query.")
90
+
91
+ model_config = {"frozen": True}
92
+
93
+
94
+ class StatTest(BaseModel):
95
+ """Run a statistical test on a dataset column.
96
+
97
+ Args:
98
+ action_type: Must be ``"STAT_TEST"``.
99
+ test_type: One of ``"zscore"``, ``"iqr"``, ``"ks"``.
100
+ column: Column name to test.
101
+ threshold: Optional threshold override. Defaults vary by test type.
102
+
103
+ Example::
104
+
105
+ >>> StatTest(action_type="STAT_TEST", test_type="zscore", column="rating")
106
+ """
107
+
108
+ action_type: Literal["STAT_TEST"]
109
+ test_type: Literal["zscore", "iqr", "ks"] = Field(description="Statistical test to run.")
110
+ column: str = Field(min_length=1, description="Column name to test.")
111
+ threshold: float | None = Field(default=None, description="Optional threshold override.")
112
+
113
+ model_config = {"frozen": True}
114
+
115
+
116
+ class PatternMatch(BaseModel):
117
+ """Evaluate a regex pattern against column values.
118
+
119
+ Args:
120
+ action_type: Must be ``"PATTERN_MATCH"``.
121
+ pattern: Regular expression string.
122
+ column: Column name to evaluate.
123
+ expect_match: If True, report rows that match. If False, report non-matches.
124
+
125
+ Example::
126
+
127
+ >>> PatternMatch(
128
+ ... action_type="PATTERN_MATCH",
129
+ ... pattern=r"^\\d{5}$",
130
+ ... column="zip_code",
131
+ ... )
132
+ """
133
+
134
+ action_type: Literal["PATTERN_MATCH"]
135
+ pattern: str = Field(min_length=1, description="Regex pattern.")
136
+ column: str = Field(min_length=1, description="Column name to evaluate.")
137
+ expect_match: bool = Field(
138
+ default=True,
139
+ description="True to report matches, False to report non-matches.",
140
+ )
141
+
142
+ model_config = {"frozen": True}
143
+
144
+
145
+ class Hypothesis(BaseModel):
146
+ """Record a causal-root claim for root-cause credit.
147
+
148
+ Args:
149
+ action_type: Must be ``"HYPOTHESIS"``.
150
+ claim: Textual description of the hypothesized root cause.
151
+ affected_rows: Row indices believed to be affected.
152
+ affected_columns: Column names believed to be affected.
153
+ root_cause_type: Detector-vocabulary root cause type
154
+ (e.g., ``"decimal_shift"``, ``"type_mismatch"``).
155
+
156
+ Example::
157
+
158
+ >>> Hypothesis(
159
+ ... action_type="HYPOTHESIS",
160
+ ... claim="Column 'rating' has a decimal shift at row 5",
161
+ ... affected_rows=[5],
162
+ ... affected_columns=["rating"],
163
+ ... root_cause_type="decimal_shift",
164
+ ... )
165
+ """
166
+
167
+ action_type: Literal["HYPOTHESIS"]
168
+ claim: str = Field(min_length=1, description="Root-cause claim.")
169
+ affected_rows: list[int] = Field(min_length=1, description="Affected row indices.")
170
+ affected_columns: list[str] = Field(min_length=1, description="Affected column names.")
171
+ root_cause_type: str = Field(min_length=1, description="Detector-vocabulary root cause type.")
172
+
173
+ @field_validator("affected_rows")
174
+ @classmethod
175
+ def _validate_affected_rows(cls, v: list[int]) -> list[int]:
176
+ """Validate that all affected row indices are non-negative."""
177
+ if any(i < 0 for i in v):
178
+ raise ValueError("All affected row indices must be >= 0")
179
+ return v
180
+
181
+ model_config = {"frozen": True}
182
+
183
+
184
+ class RootCause(BaseModel):
185
+ """Analyze selected detected errors for minimal causal roots.
186
+
187
+ Args:
188
+ action_type: Must be ``"ROOT_CAUSE"``.
189
+ error_indices: Zero-based indices into the episode's detected issue list.
190
+
191
+ Example::
192
+
193
+ >>> RootCause(action_type="ROOT_CAUSE", error_indices=[0, 1])
194
+ """
195
+
196
+ action_type: Literal["ROOT_CAUSE"]
197
+ error_indices: list[int] = Field(min_length=1, description="Detected issue indices.")
198
+
199
+ @field_validator("error_indices")
200
+ @classmethod
201
+ def _validate_error_indices(cls, v: list[int]) -> list[int]:
202
+ """Validate that all error indices are non-negative."""
203
+ if any(i < 0 for i in v):
204
+ raise ValueError("All error indices must be >= 0")
205
+ return v
206
+
207
+ model_config = {"frozen": True}
208
+
209
+
210
+ class Diagnose(BaseModel):
211
+ """Flag a suspected data-quality issue at a specific (row, column).
212
+
213
+ Args:
214
+ action_type: Must be ``"DIAGNOSE"``.
215
+ row: Zero-indexed row number.
216
+ column: Column name.
217
+ issue_type: Issue type from detector vocabulary.
218
+
219
+ Example::
220
+
221
+ >>> Diagnose(
222
+ ... action_type="DIAGNOSE",
223
+ ... row=5, column="rating",
224
+ ... issue_type="decimal_shift",
225
+ ... )
226
+ """
227
+
228
+ action_type: Literal["DIAGNOSE"]
229
+ row: int = Field(ge=0, description="Zero-indexed row number.")
230
+ column: str = Field(min_length=1, description="Column name.")
231
+ issue_type: str = Field(min_length=1, description="Issue type classification.")
232
+
233
+ model_config = {"frozen": True}
234
+
235
+
236
+ class Fix(BaseModel):
237
+ """Propose a corrected value for a diagnosed issue.
238
+
239
+ Args:
240
+ action_type: Must be ``"FIX"``.
241
+ row: Zero-indexed row number.
242
+ column: Column name.
243
+ new_value: The corrected cell value as a string.
244
+ justification: Explanation of why this fix is correct.
245
+ fix_type: How to fix the issue. Defaults to ``"correct_value"``.
246
+
247
+ Example::
248
+
249
+ >>> Fix(
250
+ ... action_type="FIX",
251
+ ... row=5, column="rating",
252
+ ... new_value="4.5",
253
+ ... justification="Decimal shift: 45.0 should be 4.5",
254
+ ... )
255
+ """
256
+
257
+ action_type: Literal["FIX"]
258
+ row: int = Field(ge=0, description="Zero-indexed row number.")
259
+ column: str = Field(min_length=1, description="Column name.")
260
+ new_value: str = Field(description="Corrected cell value.")
261
+ justification: str = Field(min_length=1, description="Fix justification.")
262
+ fix_type: Literal["correct_value", "delete_row", "impute", "standardize"] = Field(
263
+ default="correct_value", description="Fix operation type."
264
+ )
265
+
266
+ model_config = {"frozen": True}
267
+
268
+
269
+ # ═══════════════════════════════════════════════════════════════════════════
270
+ # Discriminated union and parser
271
+ # ═══════════════════════════════════════════════════════════════════════════
272
+
273
+ Action = Annotated[
274
+ InspectRows | SqlQuery | StatTest | PatternMatch | Hypothesis | RootCause | Diagnose | Fix,
275
+ Field(discriminator="action_type"),
276
+ ]
277
+ """Discriminated union of all valid DataForge environment actions."""
278
+
279
+
280
+ def parse_action(raw: dict[str, Any]) -> Action:
281
+ """Parse and validate a raw action dict into the appropriate typed model.
282
+
283
+ This is the single entry point for HTTP handlers and tests to validate
284
+ actions. The ``action_type`` field is used as the discriminator.
285
+
286
+ Args:
287
+ raw: Dictionary with an ``action_type`` key and action-specific fields.
288
+
289
+ Returns:
290
+ A validated action model instance.
291
+
292
+ Raises:
293
+ pydantic.ValidationError: If the action is malformed or invalid.
294
+ KeyError: If ``action_type`` is missing.
295
+ ValueError: If ``action_type`` is not recognized.
296
+
297
+ Example::
298
+
299
+ >>> action = parse_action({"action_type": "INSPECT_ROWS", "row_indices": [0]})
300
+ >>> isinstance(action, InspectRows)
301
+ True
302
+ """
303
+ from pydantic import TypeAdapter
304
+
305
+ adapter: TypeAdapter[Action] = TypeAdapter(Action)
306
+ return adapter.validate_python(_normalize_action(raw))
307
+
308
+
309
+ def _normalize_action(raw: dict[str, Any]) -> dict[str, Any]:
310
+ """Return a canonical action dictionary from supported external aliases."""
311
+ normalized = dict(raw)
312
+ action_type = normalized.get("action_type")
313
+ if action_type == "SQL_QUERY" and "sql" in normalized and "query" not in normalized:
314
+ normalized["query"] = normalized["sql"]
315
+ if action_type == "STAT_TEST" and "test" in normalized and "test_type" not in normalized:
316
+ normalized["test_type"] = normalized["test"]
317
+ if action_type == "PATTERN_MATCH":
318
+ if "regex" in normalized and "pattern" not in normalized:
319
+ normalized["pattern"] = normalized["regex"]
320
+ if "expect" in normalized and "expect_match" not in normalized:
321
+ normalized["expect_match"] = normalized["expect"] == "match"
322
+ if action_type == "HYPOTHESIS":
323
+ root_column = normalized.get("root_column")
324
+ downstream = normalized.get("downstream")
325
+ if root_column is not None and "affected_columns" not in normalized:
326
+ downstream_columns = downstream if isinstance(downstream, list) else []
327
+ normalized["affected_columns"] = [root_column, *downstream_columns]
328
+ if "affected_rows" not in normalized:
329
+ normalized["affected_rows"] = [0]
330
+ if root_column is not None and "root_cause_type" not in normalized:
331
+ normalized["root_cause_type"] = root_column
332
+ if (
333
+ action_type == "ROOT_CAUSE"
334
+ and "indices" in normalized
335
+ and "error_indices" not in normalized
336
+ ):
337
+ normalized["error_indices"] = normalized["indices"]
338
+ if action_type == "FIX":
339
+ if "proposed_value" in normalized and "new_value" not in normalized:
340
+ normalized["new_value"] = normalized["proposed_value"]
341
+ if "justification" not in normalized:
342
+ normalized["justification"] = "Agent proposed value via FIX."
343
+ return normalized
dataforge/bench/core.py CHANGED
@@ -59,6 +59,7 @@ class SeedBenchmarkResult(BaseModel):
59
  prompt_tokens: int = Field(ge=0, default=0)
60
  completion_tokens: int = Field(ge=0, default=0)
61
  quota_units: float = Field(ge=0.0, default=0.0)
 
62
  runtime_s: float = Field(ge=0.0, default=0.0)
63
  provider: str | None = None
64
  model: str | None = None
@@ -85,6 +86,8 @@ class AggregateBenchmarkResult(BaseModel):
85
  avg_steps_std: float | None = None
86
  quota_units_mean: float | None = None
87
  quota_units_std: float | None = None
 
 
88
  runtime_s_mean: float | None = None
89
  runtime_s_std: float | None = None
90
  provider: str | None = None
@@ -229,6 +232,7 @@ def aggregate_seed_results(
229
  f1_mean, f1_std = _mean_std([row.f1 or 0.0 for row in ok_rows])
230
  avg_steps_mean, avg_steps_std = _mean_std([row.avg_steps or 0.0 for row in ok_rows])
231
  quota_mean, quota_std = _mean_std([row.quota_units for row in ok_rows])
 
232
  runtime_mean, runtime_std = _mean_std([row.runtime_s for row in ok_rows])
233
  aggregates.append(
234
  AggregateBenchmarkResult(
@@ -248,6 +252,8 @@ def aggregate_seed_results(
248
  avg_steps_std=avg_steps_std,
249
  quota_units_mean=quota_mean,
250
  quota_units_std=quota_std,
 
 
251
  runtime_s_mean=runtime_mean,
252
  runtime_s_std=runtime_std,
253
  provider=ok_rows[0].provider,
 
59
  prompt_tokens: int = Field(ge=0, default=0)
60
  completion_tokens: int = Field(ge=0, default=0)
61
  quota_units: float = Field(ge=0.0, default=0.0)
62
+ gpu_hours: float = Field(ge=0.0, default=0.0)
63
  runtime_s: float = Field(ge=0.0, default=0.0)
64
  provider: str | None = None
65
  model: str | None = None
 
86
  avg_steps_std: float | None = None
87
  quota_units_mean: float | None = None
88
  quota_units_std: float | None = None
89
+ gpu_hours_mean: float | None = None
90
+ gpu_hours_std: float | None = None
91
  runtime_s_mean: float | None = None
92
  runtime_s_std: float | None = None
93
  provider: str | None = None
 
232
  f1_mean, f1_std = _mean_std([row.f1 or 0.0 for row in ok_rows])
233
  avg_steps_mean, avg_steps_std = _mean_std([row.avg_steps or 0.0 for row in ok_rows])
234
  quota_mean, quota_std = _mean_std([row.quota_units for row in ok_rows])
235
+ gpu_hours_mean, gpu_hours_std = _mean_std([row.gpu_hours for row in ok_rows])
236
  runtime_mean, runtime_std = _mean_std([row.runtime_s for row in ok_rows])
237
  aggregates.append(
238
  AggregateBenchmarkResult(
 
252
  avg_steps_std=avg_steps_std,
253
  quota_units_mean=quota_mean,
254
  quota_units_std=quota_std,
255
+ gpu_hours_mean=gpu_hours_mean,
256
+ gpu_hours_std=gpu_hours_std,
257
  runtime_s_mean=runtime_mean,
258
  runtime_s_std=runtime_std,
259
  provider=ok_rows[0].provider,
dataforge/bench/groq_client.py CHANGED
@@ -1,21 +1,45 @@
1
- """Minimal Groq client for benchmark-only LLM baselines."""
2
 
3
  from __future__ import annotations
4
 
5
  import json
 
6
  import time
7
  from dataclasses import dataclass
8
  from typing import cast
9
 
10
  import httpx
11
- from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed
 
 
 
 
 
 
 
12
 
13
 
14
  def _is_rate_limit_error(exc: BaseException) -> bool:
15
- """Return whether an exception is a Groq 429 response."""
16
  return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  @dataclass(frozen=True, kw_only=True)
20
  class GroqCompletion:
21
  """Completion payload plus conservative usage accounting."""
@@ -26,26 +50,50 @@ class GroqCompletion:
26
  warnings: tuple[str, ...]
27
 
28
 
29
- class GroqBenchClient:
30
- """Sequential Groq client with fixed 429 retry and spacing."""
31
 
32
  def __init__(
33
  self,
34
  *,
35
  api_key: str,
36
- model: str = "llama-3.3-70b-versatile",
 
 
37
  min_interval_s: float = 2.0,
 
 
 
 
38
  ) -> None:
39
  self._api_key = api_key
40
  self._model = model
 
 
41
  self._min_interval_s = min_interval_s
 
 
 
 
42
  self._last_success_at: float | None = None
 
 
 
 
 
 
 
43
 
44
  @property
45
  def model(self) -> str:
46
- """Return the configured Groq model name."""
47
  return self._model
48
 
 
 
 
 
 
49
  def _respect_spacing(self) -> None:
50
  """Sleep long enough to keep requests sequential with a fixed gap."""
51
  if self._last_success_at is None:
@@ -55,33 +103,57 @@ class GroqBenchClient:
55
  if remaining > 0:
56
  time.sleep(remaining)
57
 
58
- @retry(
59
- retry=retry_if_exception(_is_rate_limit_error),
60
- wait=wait_fixed(2),
61
- stop=stop_after_attempt(3),
62
- reraise=True,
63
- )
64
  def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
65
- """Issue the underlying Groq chat-completions request."""
66
  payload = {
67
  "model": self._model,
68
  "messages": messages,
69
  "temperature": 0.0,
 
70
  }
71
- with httpx.Client(timeout=60.0) as client:
72
- response = client.post(
73
- "https://api.groq.com/openai/v1/chat/completions",
74
- json=payload,
75
- headers={
76
- "Authorization": f"Bearer {self._api_key}",
77
- "Content-Type": "application/json",
78
- },
79
- )
80
- response.raise_for_status()
81
- return dict(response.json())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
84
- """Send one benchmark completion request to Groq."""
85
  self._respect_spacing()
86
  payload = self._post(messages)
87
  self._last_success_at = time.monotonic()
@@ -92,16 +164,223 @@ class GroqBenchClient:
92
  completion_tokens = int(usage.get("completion_tokens", 0)) if isinstance(usage, dict) else 0
93
  if not usage:
94
  warnings.append("missing_usage_payload")
 
 
 
95
 
96
  try:
97
  choices = cast(list[dict[str, object]], payload["choices"])
98
  message = cast(dict[str, object], choices[0]["message"])
99
  content = str(message["content"])
100
  except (KeyError, IndexError, TypeError) as exc:
101
- raise ValueError(f"Unexpected Groq response payload: {json.dumps(payload)}") from exc
 
 
102
  return GroqCompletion(
103
  text=content,
104
  prompt_tokens=prompt_tokens,
105
  completion_tokens=completion_tokens,
106
  warnings=tuple(warnings),
107
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal OpenAI-compatible clients for benchmark-only LLM baselines."""
2
 
3
  from __future__ import annotations
4
 
5
  import json
6
+ import logging
7
  import time
8
  from dataclasses import dataclass
9
  from typing import cast
10
 
11
  import httpx
12
+
13
+
14
+ class ProviderRequestError(RuntimeError):
15
+ """Raised when a provider rejects a benchmark request payload."""
16
+
17
+
18
+ class ProviderRateLimitError(ProviderRequestError):
19
+ """Raised when a provider asks us to wait longer than the configured cap."""
20
 
21
 
22
  def _is_rate_limit_error(exc: BaseException) -> bool:
23
+ """Return whether an exception is an HTTP 429 response."""
24
  return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429
25
 
26
 
27
+ def _is_retryable_provider_error(exc: BaseException) -> bool:
28
+ """Return whether an HTTP error is worth retrying for teacher collection."""
29
+ return isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code in {429, 503}
30
+
31
+
32
+ def _retry_after_s(exc: httpx.HTTPStatusError, *, fallback_s: float) -> float:
33
+ """Return provider retry-after delay when present."""
34
+ raw_retry_after = exc.response.headers.get("retry-after")
35
+ if raw_retry_after is None:
36
+ return fallback_s
37
+ try:
38
+ return max(float(raw_retry_after), fallback_s)
39
+ except ValueError:
40
+ return fallback_s
41
+
42
+
43
  @dataclass(frozen=True, kw_only=True)
44
  class GroqCompletion:
45
  """Completion payload plus conservative usage accounting."""
 
50
  warnings: tuple[str, ...]
51
 
52
 
53
+ class OpenAICompatBenchClient:
54
+ """Sequential OpenAI-compatible client with fixed 429 retry and spacing."""
55
 
56
  def __init__(
57
  self,
58
  *,
59
  api_key: str,
60
+ model: str,
61
+ endpoint: str,
62
+ provider: str,
63
  min_interval_s: float = 2.0,
64
+ max_tokens: int = 512,
65
+ max_retries: int = 5,
66
+ max_retry_after_s: float = 120.0,
67
+ timeout_s: float = 60.0,
68
  ) -> None:
69
  self._api_key = api_key
70
  self._model = model
71
+ self._endpoint = endpoint
72
+ self._provider = provider
73
  self._min_interval_s = min_interval_s
74
+ self._max_tokens = max_tokens
75
+ self._max_retries = max_retries
76
+ self._max_retry_after_s = max_retry_after_s
77
+ self._timeout_s = timeout_s
78
  self._last_success_at: float | None = None
79
+ self._client = httpx.Client(
80
+ timeout=self._timeout_s,
81
+ headers={
82
+ "Authorization": f"Bearer {self._api_key}",
83
+ "Content-Type": "application/json",
84
+ },
85
+ )
86
 
87
  @property
88
  def model(self) -> str:
89
+ """Return the configured provider model name."""
90
  return self._model
91
 
92
+ @property
93
+ def provider(self) -> str:
94
+ """Return the configured provider identifier."""
95
+ return self._provider
96
+
97
  def _respect_spacing(self) -> None:
98
  """Sleep long enough to keep requests sequential with a fixed gap."""
99
  if self._last_success_at is None:
 
103
  if remaining > 0:
104
  time.sleep(remaining)
105
 
 
 
 
 
 
 
106
  def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
107
+ """Issue the underlying chat-completions request."""
108
  payload = {
109
  "model": self._model,
110
  "messages": messages,
111
  "temperature": 0.0,
112
+ "max_tokens": self._max_tokens,
113
  }
114
+ last_rate_limit_error: httpx.HTTPStatusError | None = None
115
+ for attempt in range(self._max_retries):
116
+ response: httpx.Response | None = None
117
+ try:
118
+ response = self._client.post(
119
+ self._endpoint,
120
+ json=payload,
121
+ )
122
+ response.raise_for_status()
123
+ except httpx.HTTPStatusError as exc:
124
+ if not _is_retryable_provider_error(exc) or attempt == self._max_retries - 1:
125
+ body = exc.response.text[:500].replace("\n", " ")
126
+ raise ProviderRequestError(
127
+ f"{self._provider} request rejected with HTTP "
128
+ f"{exc.response.status_code}: {body}"
129
+ ) from exc
130
+ last_rate_limit_error = exc
131
+ retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
132
+ if retry_s > self._max_retry_after_s:
133
+ body = exc.response.text[:500].replace("\n", " ")
134
+ raise ProviderRateLimitError(
135
+ f"{self._provider} rate limit retry-after {retry_s:.2f}s "
136
+ f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
137
+ ) from exc
138
+ logging.getLogger("dataforge.bench.groq_client").warning(
139
+ "%s_rate_limit attempt=%d retry_after_s=%.2f",
140
+ self._provider,
141
+ attempt + 1,
142
+ retry_s,
143
+ )
144
+ time.sleep(retry_s)
145
+ continue
146
+ except httpx.TimeoutException as exc:
147
+ raise TimeoutError(
148
+ f"{self._provider} request timed out after {self._timeout_s:.1f} seconds."
149
+ ) from exc
150
+ return dict(response.json())
151
+ if last_rate_limit_error is not None:
152
+ raise last_rate_limit_error
153
+ raise RuntimeError(f"{self._provider} request failed without a response.")
154
 
155
  def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
156
+ """Send one benchmark completion request to the configured provider."""
157
  self._respect_spacing()
158
  payload = self._post(messages)
159
  self._last_success_at = time.monotonic()
 
164
  completion_tokens = int(usage.get("completion_tokens", 0)) if isinstance(usage, dict) else 0
165
  if not usage:
166
  warnings.append("missing_usage_payload")
167
+ logging.getLogger("dataforge.bench.groq_client").warning(
168
+ "%s_missing_usage_payload", self._provider
169
+ )
170
 
171
  try:
172
  choices = cast(list[dict[str, object]], payload["choices"])
173
  message = cast(dict[str, object], choices[0]["message"])
174
  content = str(message["content"])
175
  except (KeyError, IndexError, TypeError) as exc:
176
+ raise ValueError(
177
+ f"Unexpected {self._provider} response payload: {json.dumps(payload)}"
178
+ ) from exc
179
  return GroqCompletion(
180
  text=content,
181
  prompt_tokens=prompt_tokens,
182
  completion_tokens=completion_tokens,
183
  warnings=tuple(warnings),
184
  )
185
+
186
+
187
+ class GroqBenchClient(OpenAICompatBenchClient):
188
+ """Sequential Groq client with fixed 429 retry and spacing."""
189
+
190
+ def __init__(
191
+ self,
192
+ *,
193
+ api_key: str,
194
+ model: str = "llama-3.3-70b-versatile",
195
+ min_interval_s: float = 2.0,
196
+ max_tokens: int = 512,
197
+ max_retries: int = 5,
198
+ max_retry_after_s: float = 120.0,
199
+ timeout_s: float = 60.0,
200
+ ) -> None:
201
+ super().__init__(
202
+ api_key=api_key,
203
+ model=model,
204
+ endpoint="https://api.groq.com/openai/v1/chat/completions",
205
+ provider="groq",
206
+ min_interval_s=min_interval_s,
207
+ max_tokens=max_tokens,
208
+ max_retries=max_retries,
209
+ max_retry_after_s=max_retry_after_s,
210
+ timeout_s=timeout_s,
211
+ )
212
+
213
+
214
+ class CerebrasBenchClient(OpenAICompatBenchClient):
215
+ """Sequential Cerebras client with fixed 429 retry and spacing."""
216
+
217
+ def __init__(
218
+ self,
219
+ *,
220
+ api_key: str,
221
+ model: str = "qwen-3-235b-a22b-instruct-2507",
222
+ min_interval_s: float = 0.5,
223
+ max_tokens: int = 512,
224
+ max_retries: int = 5,
225
+ max_retry_after_s: float = 120.0,
226
+ timeout_s: float = 60.0,
227
+ ) -> None:
228
+ super().__init__(
229
+ api_key=api_key,
230
+ model=model,
231
+ endpoint="https://api.cerebras.ai/v1/chat/completions",
232
+ provider="cerebras",
233
+ min_interval_s=min_interval_s,
234
+ max_tokens=max_tokens,
235
+ max_retries=max_retries,
236
+ max_retry_after_s=max_retry_after_s,
237
+ timeout_s=timeout_s,
238
+ )
239
+
240
+
241
+ class GeminiBenchClient:
242
+ """Sequential Gemini client adapted to the benchmark completion interface."""
243
+
244
+ def __init__(
245
+ self,
246
+ *,
247
+ api_key: str,
248
+ model: str = "gemini-3.1-pro-preview",
249
+ min_interval_s: float = 2.0,
250
+ max_tokens: int = 512,
251
+ max_retries: int = 5,
252
+ max_retry_after_s: float = 120.0,
253
+ timeout_s: float = 60.0,
254
+ ) -> None:
255
+ self._api_key = api_key
256
+ self._model = model.removeprefix("models/")
257
+ self._min_interval_s = min_interval_s
258
+ self._max_tokens = max_tokens
259
+ self._max_retries = max_retries
260
+ self._max_retry_after_s = max_retry_after_s
261
+ self._timeout_s = timeout_s
262
+ self._last_success_at: float | None = None
263
+ self._client = httpx.Client(
264
+ timeout=self._timeout_s,
265
+ headers={"Content-Type": "application/json"},
266
+ )
267
+
268
+ @property
269
+ def model(self) -> str:
270
+ """Return the configured Gemini model name."""
271
+ return self._model
272
+
273
+ @property
274
+ def provider(self) -> str:
275
+ """Return the provider identifier."""
276
+ return "gemini"
277
+
278
+ def _respect_spacing(self) -> None:
279
+ """Sleep long enough to keep requests sequential with a fixed gap."""
280
+ if self._last_success_at is None:
281
+ return
282
+ elapsed = time.monotonic() - self._last_success_at
283
+ remaining = self._min_interval_s - elapsed
284
+ if remaining > 0:
285
+ time.sleep(remaining)
286
+
287
+ def _payload(self, messages: list[dict[str, str]]) -> dict[str, object]:
288
+ """Convert OpenAI-style chat messages to Gemini generateContent payload."""
289
+ system_texts: list[str] = []
290
+ contents: list[dict[str, object]] = []
291
+ for message in messages:
292
+ role = message.get("role", "user")
293
+ content = message.get("content", "")
294
+ if role == "system":
295
+ system_texts.append(content)
296
+ continue
297
+ gemini_role = "model" if role == "assistant" else "user"
298
+ contents.append({"role": gemini_role, "parts": [{"text": content}]})
299
+
300
+ payload: dict[str, object] = {
301
+ "contents": contents,
302
+ "generationConfig": {
303
+ "temperature": 0.0,
304
+ "maxOutputTokens": self._max_tokens,
305
+ },
306
+ }
307
+ if system_texts:
308
+ payload["systemInstruction"] = {
309
+ "parts": [{"text": "\n\n".join(system_texts)}],
310
+ }
311
+ return payload
312
+
313
+ def _post(self, messages: list[dict[str, str]]) -> dict[str, object]:
314
+ """Issue the underlying Gemini generateContent request."""
315
+ endpoint = (
316
+ f"https://generativelanguage.googleapis.com/v1beta/models/{self._model}:generateContent"
317
+ )
318
+ last_rate_limit_error: httpx.HTTPStatusError | None = None
319
+ for attempt in range(self._max_retries):
320
+ response: httpx.Response | None = None
321
+ try:
322
+ response = self._client.post(
323
+ endpoint,
324
+ params={"key": self._api_key},
325
+ json=self._payload(messages),
326
+ )
327
+ response.raise_for_status()
328
+ except httpx.HTTPStatusError as exc:
329
+ if not _is_retryable_provider_error(exc) or attempt == self._max_retries - 1:
330
+ body = exc.response.text[:500].replace("\n", " ")
331
+ raise ProviderRequestError(
332
+ f"gemini request rejected with HTTP {exc.response.status_code}: {body}"
333
+ ) from exc
334
+ last_rate_limit_error = exc
335
+ retry_s = _retry_after_s(exc, fallback_s=2.0 * (attempt + 1))
336
+ if retry_s > self._max_retry_after_s:
337
+ body = exc.response.text[:500].replace("\n", " ")
338
+ raise ProviderRateLimitError(
339
+ f"gemini rate limit retry-after {retry_s:.2f}s "
340
+ f"exceeds cap {self._max_retry_after_s:.2f}s: {body}"
341
+ ) from exc
342
+ logging.getLogger("dataforge.bench.groq_client").warning(
343
+ "gemini_rate_limit attempt=%d retry_after_s=%.2f",
344
+ attempt + 1,
345
+ retry_s,
346
+ )
347
+ time.sleep(retry_s)
348
+ continue
349
+ except httpx.TimeoutException as exc:
350
+ raise TimeoutError(
351
+ f"gemini request timed out after {self._timeout_s:.1f} seconds."
352
+ ) from exc
353
+ return dict(response.json())
354
+ if last_rate_limit_error is not None:
355
+ raise last_rate_limit_error
356
+ raise RuntimeError("gemini request failed without a response.")
357
+
358
+ def complete(self, messages: list[dict[str, str]]) -> GroqCompletion:
359
+ """Send one benchmark completion request to Gemini."""
360
+ self._respect_spacing()
361
+ payload = self._post(messages)
362
+ self._last_success_at = time.monotonic()
363
+
364
+ warnings: list[str] = []
365
+ usage = payload.get("usageMetadata", {})
366
+ prompt_tokens = int(usage.get("promptTokenCount", 0)) if isinstance(usage, dict) else 0
367
+ completion_tokens = (
368
+ int(usage.get("candidatesTokenCount", 0)) if isinstance(usage, dict) else 0
369
+ )
370
+ if not usage:
371
+ warnings.append("missing_usage_payload")
372
+ logging.getLogger("dataforge.bench.groq_client").warning("gemini_missing_usage_payload")
373
+
374
+ try:
375
+ candidates = cast(list[dict[str, object]], payload["candidates"])
376
+ content = cast(dict[str, object], candidates[0]["content"])
377
+ parts = cast(list[dict[str, object]], content["parts"])
378
+ text = "".join(str(part.get("text", "")) for part in parts)
379
+ except (KeyError, IndexError, TypeError) as exc:
380
+ raise ValueError(f"Unexpected gemini response payload: {json.dumps(payload)}") from exc
381
+ return GroqCompletion(
382
+ text=text,
383
+ prompt_tokens=prompt_tokens,
384
+ completion_tokens=completion_tokens,
385
+ warnings=tuple(warnings),
386
+ )
dataforge/bench/methods.py CHANGED
@@ -151,6 +151,40 @@ def _column_stats(
151
  return stats
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def _extract_json_object(text: str) -> dict[str, object] | None:
155
  """Parse the first JSON object found in an LLM response string."""
156
  stripped = text.strip()
@@ -158,6 +192,7 @@ def _extract_json_object(text: str) -> dict[str, object] | None:
158
  stripped = stripped.strip("`")
159
  if stripped.lower().startswith("json"):
160
  stripped = stripped[4:].strip()
 
161
  decoder = json.JSONDecoder()
162
  for offset, char in enumerate(stripped):
163
  if char != "{":
 
151
  return stats
152
 
153
 
154
+ def _strip_json_line_comments(text: str) -> str:
155
+ """Remove JavaScript-style line comments outside JSON strings."""
156
+ result: list[str] = []
157
+ in_string = False
158
+ escaped = False
159
+ index = 0
160
+ while index < len(text):
161
+ char = text[index]
162
+ next_char = text[index + 1] if index + 1 < len(text) else ""
163
+ if in_string:
164
+ result.append(char)
165
+ if escaped:
166
+ escaped = False
167
+ elif char == "\\":
168
+ escaped = True
169
+ elif char == '"':
170
+ in_string = False
171
+ index += 1
172
+ continue
173
+ if char == '"':
174
+ in_string = True
175
+ result.append(char)
176
+ index += 1
177
+ continue
178
+ if char == "/" and next_char == "/":
179
+ index += 2
180
+ while index < len(text) and text[index] not in "\r\n":
181
+ index += 1
182
+ continue
183
+ result.append(char)
184
+ index += 1
185
+ return "".join(result)
186
+
187
+
188
  def _extract_json_object(text: str) -> dict[str, object] | None:
189
  """Parse the first JSON object found in an LLM response string."""
190
  stripped = text.strip()
 
192
  stripped = stripped.strip("`")
193
  if stripped.lower().startswith("json"):
194
  stripped = stripped[4:].strip()
195
+ stripped = _strip_json_line_comments(stripped)
196
  decoder = json.JSONDecoder()
197
  for offset, char in enumerate(stripped):
198
  if char != "{":
dataforge/bench/report.py CHANGED
@@ -69,13 +69,14 @@ def _aggregate_across_datasets(aggregates: list[AggregateBenchmarkResult]) -> li
69
  for method in methods:
70
  ok_rows = grouped.get(method, [])
71
  if not ok_rows:
72
- rows.append([method, "Skipped", "Skipped", "Skipped", "Skipped", "Skipped"])
73
  continue
74
  p_mean = sum(row.precision_mean or 0.0 for row in ok_rows) / len(ok_rows)
75
  r_mean = sum(row.recall_mean or 0.0 for row in ok_rows) / len(ok_rows)
76
  f_mean = sum(row.f1_mean or 0.0 for row in ok_rows) / len(ok_rows)
77
  step_mean = sum(row.avg_steps_mean or 0.0 for row in ok_rows) / len(ok_rows)
78
  quota_mean = sum(row.quota_units_mean or 0.0 for row in ok_rows) / len(ok_rows)
 
79
  rows.append(
80
  [
81
  method,
@@ -84,6 +85,7 @@ def _aggregate_across_datasets(aggregates: list[AggregateBenchmarkResult]) -> li
84
  f"{f_mean:.4f}",
85
  f"{step_mean:.2f}",
86
  f"{quota_mean:.4f}",
 
87
  ]
88
  )
89
  return rows
@@ -104,15 +106,13 @@ def build_readme_benchmark_block(agent_output: BenchmarkRunOutput, report_path:
104
  """Build the generated README benchmark summary block."""
105
  rows = _aggregate_across_datasets(agent_output.aggregates)
106
  table = _render_table(
107
- ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"],
108
  rows,
109
  )
110
  skip_reasons = _collect_skip_reasons(agent_output.aggregates)
111
  skip_note = ""
112
  if skip_reasons:
113
- skip_note = (
114
- "\n\nSkipped methods in this run: " + "; ".join(skip_reasons)
115
- )
116
  return (
117
  "Generated from `eval/results/agent_comparison.json`.\n\n"
118
  f"{table}\n\n"
@@ -140,19 +140,28 @@ def render_benchmark_report(
140
  _format_metric(row.f1_mean, row.f1_std),
141
  _format_metric(row.avg_steps_mean, row.avg_steps_std),
142
  _format_metric(row.quota_units_mean, row.quota_units_std),
 
143
  ]
144
  for row in rows
145
  ]
146
  per_dataset_sections.append(
147
  f"### {dataset.title()}\n\n"
148
  + _render_table(
149
- ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"],
 
 
 
 
 
 
 
 
150
  table_rows,
151
  )
152
  )
153
 
154
  local_summary = _render_table(
155
- ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units"],
156
  _aggregate_across_datasets(agent_output.aggregates),
157
  )
158
 
@@ -179,11 +188,7 @@ def render_benchmark_report(
179
  skip_reasons = _collect_skip_reasons(agent_output.aggregates)
180
  skip_note = ""
181
  if skip_reasons:
182
- skip_note = (
183
- "\nSkipped methods in this reproduced run: "
184
- + "; ".join(skip_reasons)
185
- + "\n"
186
- )
187
 
188
  method_values = agent_output.metadata.get("methods", [])
189
  dataset_values = agent_output.metadata.get("datasets", [])
@@ -203,6 +208,7 @@ def render_benchmark_report(
203
  f"- Datasets: {', '.join(datasets)}\n"
204
  f"- Seeds: {seeds}\n"
205
  "- Free-tier quota units: `max(llm_calls / 1000, (prompt_tokens + completion_tokens) / 100000)`\n"
 
206
  f"{skip_note}\n"
207
  "## Cross-Dataset Local Results\n\n"
208
  f"{local_summary}\n\n"
@@ -216,7 +222,7 @@ def render_benchmark_report(
216
  sota_rows,
217
  )
218
  + "\n\n## Methodology\n\n"
219
- + "Local rows are reproduced from generated JSON. Citation-only SOTA rows are copied from literature and are not rerun in this repository. Quota units are reported in free-tier fractions rather than dollars.\n"
220
  )
221
 
222
 
 
69
  for method in methods:
70
  ok_rows = grouped.get(method, [])
71
  if not ok_rows:
72
+ rows.append([method, "Skipped", "Skipped", "Skipped", "Skipped", "Skipped", "Skipped"])
73
  continue
74
  p_mean = sum(row.precision_mean or 0.0 for row in ok_rows) / len(ok_rows)
75
  r_mean = sum(row.recall_mean or 0.0 for row in ok_rows) / len(ok_rows)
76
  f_mean = sum(row.f1_mean or 0.0 for row in ok_rows) / len(ok_rows)
77
  step_mean = sum(row.avg_steps_mean or 0.0 for row in ok_rows) / len(ok_rows)
78
  quota_mean = sum(row.quota_units_mean or 0.0 for row in ok_rows) / len(ok_rows)
79
+ gpu_hours_mean = sum(row.gpu_hours_mean or 0.0 for row in ok_rows) / len(ok_rows)
80
  rows.append(
81
  [
82
  method,
 
85
  f"{f_mean:.4f}",
86
  f"{step_mean:.2f}",
87
  f"{quota_mean:.4f}",
88
+ f"{gpu_hours_mean:.4f}",
89
  ]
90
  )
91
  return rows
 
106
  """Build the generated README benchmark summary block."""
107
  rows = _aggregate_across_datasets(agent_output.aggregates)
108
  table = _render_table(
109
+ ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units", "GPU Hours"],
110
  rows,
111
  )
112
  skip_reasons = _collect_skip_reasons(agent_output.aggregates)
113
  skip_note = ""
114
  if skip_reasons:
115
+ skip_note = "\n\nSkipped methods in this run: " + "; ".join(skip_reasons)
 
 
116
  return (
117
  "Generated from `eval/results/agent_comparison.json`.\n\n"
118
  f"{table}\n\n"
 
140
  _format_metric(row.f1_mean, row.f1_std),
141
  _format_metric(row.avg_steps_mean, row.avg_steps_std),
142
  _format_metric(row.quota_units_mean, row.quota_units_std),
143
+ _format_metric(row.gpu_hours_mean, row.gpu_hours_std),
144
  ]
145
  for row in rows
146
  ]
147
  per_dataset_sections.append(
148
  f"### {dataset.title()}\n\n"
149
  + _render_table(
150
+ [
151
+ "Method",
152
+ "Precision",
153
+ "Recall",
154
+ "F1",
155
+ "Avg Steps",
156
+ "Quota Units",
157
+ "GPU Hours",
158
+ ],
159
  table_rows,
160
  )
161
  )
162
 
163
  local_summary = _render_table(
164
+ ["Method", "Precision", "Recall", "F1", "Avg Steps", "Quota Units", "GPU Hours"],
165
  _aggregate_across_datasets(agent_output.aggregates),
166
  )
167
 
 
188
  skip_reasons = _collect_skip_reasons(agent_output.aggregates)
189
  skip_note = ""
190
  if skip_reasons:
191
+ skip_note = "\nSkipped methods in this reproduced run: " + "; ".join(skip_reasons) + "\n"
 
 
 
 
192
 
193
  method_values = agent_output.metadata.get("methods", [])
194
  dataset_values = agent_output.metadata.get("datasets", [])
 
208
  f"- Datasets: {', '.join(datasets)}\n"
209
  f"- Seeds: {seeds}\n"
210
  "- Free-tier quota units: `max(llm_calls / 1000, (prompt_tokens + completion_tokens) / 100000)`\n"
211
+ "- GRPO compute cost is reported as free-tier GPU-hours, not dollars.\n"
212
  f"{skip_note}\n"
213
  "## Cross-Dataset Local Results\n\n"
214
  f"{local_summary}\n\n"
 
222
  sota_rows,
223
  )
224
  + "\n\n## Methodology\n\n"
225
+ + "Local rows are reproduced from generated JSON. Citation-only SOTA rows are copied from literature and are not rerun in this repository. LLM quota units are free-tier fractions; GRPO compute cost is GPU-hours, not dollars.\n"
226
  )
227
 
228
 
dataforge/bench/runner.py CHANGED
@@ -3,6 +3,7 @@
3
  from __future__ import annotations
4
 
5
  import os
 
6
  from pathlib import Path
7
 
8
  from dotenv import load_dotenv
@@ -96,18 +97,21 @@ def run_agent_comparison(
96
  output_json: Path,
97
  really_run_big_bench: bool,
98
  cache_root: Path | None = None,
 
99
  ) -> BenchmarkRunOutput:
100
  """Run the selected benchmark methods across real-world datasets."""
101
  load_dotenv()
102
  _validate_inputs(methods, datasets, seeds)
103
 
104
  estimated_calls = estimate_llm_calls(methods=methods, datasets=datasets, seeds=seeds)
 
 
105
  validate_estimated_calls(
106
  estimated_calls=estimated_calls,
107
  really_run_big_bench=really_run_big_bench,
108
  )
109
 
110
- reproduction_command = _reproduction_command(methods, datasets, seeds)
111
  records: list[SeedBenchmarkResult] = []
112
  loaded_datasets = {
113
  dataset_name: load_real_world_dataset(dataset_name, cache_root=cache_root)
@@ -116,16 +120,45 @@ def run_agent_comparison(
116
 
117
  llm_methods_requested = any(method.startswith("llm_") for method in methods)
118
  skip_reason = _llm_skip_reason() if llm_methods_requested else None
119
- client = (
120
- GroqBenchClient(api_key=os.environ["GROQ_API_KEY"])
121
- if llm_methods_requested and skip_reason is None
122
- else None
123
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  for dataset_name in datasets:
126
  dataset = loaded_datasets[dataset_name]
127
  for method in methods:
128
  for seed in range(seeds):
 
 
 
 
 
 
129
  if method == "random":
130
  result = run_random_episode(dataset, seed=seed)
131
  elif method == "heuristic":
@@ -159,6 +192,12 @@ def run_agent_comparison(
159
  if method == "heuristic":
160
  result = result.model_copy(update={"seed": seed})
161
  records.append(result)
 
 
 
 
 
 
162
 
163
  aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
164
  records, seeds_requested=seeds
 
3
  from __future__ import annotations
4
 
5
  import os
6
+ import sys
7
  from pathlib import Path
8
 
9
  from dotenv import load_dotenv
 
97
  output_json: Path,
98
  really_run_big_bench: bool,
99
  cache_root: Path | None = None,
100
+ reproduction_command: str | None = None,
101
  ) -> BenchmarkRunOutput:
102
  """Run the selected benchmark methods across real-world datasets."""
103
  load_dotenv()
104
  _validate_inputs(methods, datasets, seeds)
105
 
106
  estimated_calls = estimate_llm_calls(methods=methods, datasets=datasets, seeds=seeds)
107
+ # Validate call budget before any client instantiation or dataset loads that could
108
+ # trigger network access in tests with environment variables set.
109
  validate_estimated_calls(
110
  estimated_calls=estimated_calls,
111
  really_run_big_bench=really_run_big_bench,
112
  )
113
 
114
+ reproduction_command = reproduction_command or _reproduction_command(methods, datasets, seeds)
115
  records: list[SeedBenchmarkResult] = []
116
  loaded_datasets = {
117
  dataset_name: load_real_world_dataset(dataset_name, cache_root=cache_root)
 
120
 
121
  llm_methods_requested = any(method.startswith("llm_") for method in methods)
122
  skip_reason = _llm_skip_reason() if llm_methods_requested else None
123
+ client = None
124
+ if llm_methods_requested and skip_reason is None:
125
+ # Allow env-driven tuning for tiny CI checks.
126
+ model = os.environ.get("DATAFORGE_GROQ_MODEL", "llama-3.3-70b-versatile")
127
+ try:
128
+ min_interval_s = float(os.environ.get("DATAFORGE_GROQ_MIN_INTERVAL_S", "1.0"))
129
+ except ValueError:
130
+ min_interval_s = 1.0
131
+ try:
132
+ timeout_s = float(os.environ.get("DATAFORGE_GROQ_TIMEOUT_S", "30"))
133
+ except ValueError:
134
+ timeout_s = 30.0
135
+ try:
136
+ max_tokens = int(os.environ.get("DATAFORGE_GROQ_MAX_TOKENS", "256"))
137
+ except ValueError:
138
+ max_tokens = 256
139
+ try:
140
+ max_retries = int(os.environ.get("DATAFORGE_GROQ_MAX_RETRIES", "3"))
141
+ except ValueError:
142
+ max_retries = 3
143
+ client = GroqBenchClient(
144
+ api_key=os.environ["GROQ_API_KEY"],
145
+ model=model,
146
+ min_interval_s=min_interval_s,
147
+ max_tokens=max_tokens,
148
+ max_retries=max_retries,
149
+ timeout_s=timeout_s,
150
+ )
151
 
152
  for dataset_name in datasets:
153
  dataset = loaded_datasets[dataset_name]
154
  for method in methods:
155
  for seed in range(seeds):
156
+ if os.environ.get("DATAFORGE_BENCH_VERBOSE"):
157
+ print(
158
+ f"[dataforge bench] start method={method} dataset={dataset_name} seed={seed}",
159
+ file=sys.stderr,
160
+ flush=True,
161
+ )
162
  if method == "random":
163
  result = run_random_episode(dataset, seed=seed)
164
  elif method == "heuristic":
 
192
  if method == "heuristic":
193
  result = result.model_copy(update={"seed": seed})
194
  records.append(result)
195
+ if os.environ.get("DATAFORGE_BENCH_VERBOSE"):
196
+ print(
197
+ f"[dataforge bench] done method={method} dataset={dataset_name} seed={seed} status={result.status}",
198
+ file=sys.stderr,
199
+ flush=True,
200
+ )
201
 
202
  aggregates: list[AggregateBenchmarkResult] = aggregate_seed_results(
203
  records, seeds_requested=seeds
dataforge/causal/__init__.py CHANGED
@@ -1 +1,21 @@
1
- """Causal analysis package scaffolding for DataForge."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Causal analysis primitives for DataForge root-cause diagnosis."""
2
+
3
+ from dataforge.causal.dag import CausalDAG, CausalEdge
4
+ from dataforge.causal.pc import CausalDiscoveryResult, discover_causal_dag
5
+ from dataforge.causal.root_cause import (
6
+ CausalRootCauseAnalyzer,
7
+ ErrorEvidence,
8
+ RootCauseResult,
9
+ minimal_root_set,
10
+ )
11
+
12
+ __all__ = [
13
+ "CausalDAG",
14
+ "CausalDiscoveryResult",
15
+ "CausalEdge",
16
+ "CausalRootCauseAnalyzer",
17
+ "ErrorEvidence",
18
+ "RootCauseResult",
19
+ "discover_causal_dag",
20
+ "minimal_root_set",
21
+ ]
dataforge/causal/dag.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Column-level causal DAG utilities for root-cause analysis."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ import networkx as nx # type: ignore[import-untyped]
9
+
10
+ __all__ = ["CausalDAG", "CausalEdge"]
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class CausalEdge:
15
+ """Metadata for a directed causal edge.
16
+
17
+ Args:
18
+ source: Source column name.
19
+ target: Target column name.
20
+ confidence: Confidence in the directed influence, from 0.0 to 1.0.
21
+ provenance: Human-readable source of the edge.
22
+ """
23
+
24
+ source: str
25
+ target: str
26
+ confidence: float
27
+ provenance: str
28
+
29
+
30
+ class CausalDAG:
31
+ """Acyclic directed graph whose nodes are dataset columns.
32
+
33
+ Args:
34
+ nodes: Optional initial column names.
35
+
36
+ Example:
37
+ >>> dag = CausalDAG(["discount_pct", "order_total"])
38
+ >>> dag.add_edge("discount_pct", "order_total", confidence=0.9, provenance="fd")
39
+ >>> dag.is_reachable("discount_pct", "order_total")
40
+ True
41
+ """
42
+
43
+ def __init__(self, nodes: list[str] | tuple[str, ...] = ()) -> None:
44
+ self._graph: nx.DiGraph[Any] = nx.DiGraph()
45
+ self._graph.add_nodes_from(nodes)
46
+
47
+ @property
48
+ def nodes(self) -> tuple[str, ...]:
49
+ """Return graph nodes in insertion order."""
50
+ return tuple(str(node) for node in self._graph.nodes)
51
+
52
+ @property
53
+ def edges(self) -> tuple[CausalEdge, ...]:
54
+ """Return directed edges with metadata."""
55
+ result: list[CausalEdge] = []
56
+ for source, target, attrs in self._graph.edges(data=True):
57
+ result.append(
58
+ CausalEdge(
59
+ source=str(source),
60
+ target=str(target),
61
+ confidence=float(attrs.get("confidence", 0.0)),
62
+ provenance=str(attrs.get("provenance", "unknown")),
63
+ )
64
+ )
65
+ return tuple(result)
66
+
67
+ def add_node(self, column: str) -> None:
68
+ """Add a column node if it is not already present.
69
+
70
+ Args:
71
+ column: Column name.
72
+ """
73
+ self._graph.add_node(column)
74
+
75
+ def add_edge(
76
+ self,
77
+ source: str,
78
+ target: str,
79
+ *,
80
+ confidence: float,
81
+ provenance: str,
82
+ ) -> None:
83
+ """Add a directed causal edge while preserving acyclicity.
84
+
85
+ Args:
86
+ source: Source column name.
87
+ target: Target column name.
88
+ confidence: Confidence score from 0.0 to 1.0.
89
+ provenance: Source of the edge.
90
+
91
+ Raises:
92
+ ValueError: If the edge is self-referential or creates a cycle.
93
+ """
94
+ if source == target:
95
+ raise ValueError("Causal DAG does not allow self-edges")
96
+ self._graph.add_node(source)
97
+ self._graph.add_node(target)
98
+ if nx.has_path(self._graph, target, source):
99
+ raise ValueError(f"Adding {source!r} -> {target!r} would create a cycle")
100
+ bounded = max(0.0, min(1.0, confidence))
101
+ self._graph.add_edge(source, target, confidence=bounded, provenance=provenance)
102
+
103
+ def successors(self, column: str) -> tuple[str, ...]:
104
+ """Return direct downstream columns for a node.
105
+
106
+ Args:
107
+ column: Column name.
108
+
109
+ Returns:
110
+ A tuple of direct successor column names.
111
+ """
112
+ if column not in self._graph:
113
+ return ()
114
+ return tuple(str(node) for node in self._graph.successors(column))
115
+
116
+ def is_reachable(self, source: str, target: str) -> bool:
117
+ """Return whether target is reachable from source.
118
+
119
+ Args:
120
+ source: Source column name.
121
+ target: Target column name.
122
+
123
+ Returns:
124
+ True if source equals target or a directed path exists.
125
+ """
126
+ if source == target:
127
+ return True
128
+ if source not in self._graph or target not in self._graph:
129
+ return False
130
+ return bool(nx.has_path(self._graph, source, target))
131
+
132
+ def path_confidence(self, source: str, target: str) -> float:
133
+ """Return the weakest-edge confidence on the shortest path.
134
+
135
+ Args:
136
+ source: Source column name.
137
+ target: Target column name.
138
+
139
+ Returns:
140
+ Confidence in [0.0, 1.0], or 0.0 when no path exists.
141
+ """
142
+ if source == target:
143
+ return 1.0
144
+ if not self.is_reachable(source, target):
145
+ return 0.0
146
+ path = nx.shortest_path(self._graph, source, target)
147
+ confidences = [
148
+ float(self._graph.edges[path[i], path[i + 1]].get("confidence", 0.0))
149
+ for i in range(len(path) - 1)
150
+ ]
151
+ return min(confidences, default=0.0)
152
+
153
+ def minimal_root_columns(self, columns: list[str] | tuple[str, ...]) -> tuple[str, ...]:
154
+ """Return selected columns that are not downstream of another selection.
155
+
156
+ Args:
157
+ columns: Selected error columns.
158
+
159
+ Returns:
160
+ Minimal root columns in first-seen order.
161
+ """
162
+ unique: list[str] = []
163
+ for column in columns:
164
+ if column not in unique:
165
+ unique.append(column)
166
+
167
+ roots: list[str] = []
168
+ for column in unique:
169
+ has_upstream = any(
170
+ other != column and self.is_reachable(other, column) for other in unique
171
+ )
172
+ if not has_upstream:
173
+ roots.append(column)
174
+ return tuple(roots)
dataforge/causal/pc.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PC-based causal DAG discovery with functional-dependency priors."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from scipy.stats import chi2_contingency # type: ignore[import-untyped]
11
+
12
+ from dataforge.causal.dag import CausalDAG
13
+ from dataforge.verifier.schema import Schema
14
+
15
+ __all__ = ["CausalDiscoveryResult", "discover_causal_dag"]
16
+
17
+
18
+ @dataclass(frozen=True)
19
+ class CausalDiscoveryResult:
20
+ """Result of causal discovery.
21
+
22
+ Args:
23
+ dag: Directed acyclic graph over columns.
24
+ confidence_report: Column-pair confidence or diagnostic metadata.
25
+ warnings: Non-fatal discovery warnings.
26
+ """
27
+
28
+ dag: CausalDAG
29
+ confidence_report: dict[str, float] = field(default_factory=dict)
30
+ warnings: tuple[str, ...] = ()
31
+
32
+
33
+ def discover_causal_dag(
34
+ df: pd.DataFrame,
35
+ schema: Schema | None = None,
36
+ *,
37
+ alpha: float = 0.05,
38
+ ) -> CausalDiscoveryResult:
39
+ """Infer a deterministic causal DAG from tabular data and FD priors.
40
+
41
+ Args:
42
+ df: Input DataFrame.
43
+ schema: Optional declared schema with functional dependencies.
44
+ alpha: Significance threshold for independence checks.
45
+
46
+ Returns:
47
+ CausalDiscoveryResult. A DAG is returned even if PC orientation is
48
+ underdetermined; low-confidence edges are tagged as such.
49
+ """
50
+ columns = [str(column) for column in df.columns]
51
+ dag = CausalDAG(columns)
52
+ report: dict[str, float] = {}
53
+ warnings: list[str] = []
54
+
55
+ if schema is not None:
56
+ for fd in schema.functional_dependencies:
57
+ for determinant in fd.determinant:
58
+ _try_add_edge(
59
+ dag,
60
+ determinant,
61
+ fd.dependent,
62
+ confidence=0.95,
63
+ provenance="functional_dependency_prior",
64
+ warnings=warnings,
65
+ )
66
+ report[f"{determinant}->{fd.dependent}"] = 0.95
67
+
68
+ cleaned = _prepare_for_pc(df)
69
+ pc_edges, pc_warning = _run_causal_learn_pc(cleaned.to_numpy(), columns, alpha)
70
+ if pc_warning:
71
+ warnings.append(pc_warning)
72
+ for source, target in pc_edges:
73
+ _try_add_edge(
74
+ dag,
75
+ source,
76
+ target,
77
+ confidence=0.55,
78
+ provenance="causal_learn_pc",
79
+ warnings=warnings,
80
+ )
81
+ report.setdefault(f"{source}->{target}", 0.55)
82
+
83
+ for source, target, confidence in _pairwise_dependency_edges(df, alpha):
84
+ _try_add_edge(
85
+ dag,
86
+ source,
87
+ target,
88
+ confidence=confidence,
89
+ provenance="pairwise_ci_fallback",
90
+ warnings=warnings,
91
+ )
92
+ report.setdefault(f"{source}->{target}", confidence)
93
+
94
+ return CausalDiscoveryResult(dag=dag, confidence_report=report, warnings=tuple(warnings))
95
+
96
+
97
+ def _prepare_for_pc(df: pd.DataFrame) -> pd.DataFrame:
98
+ """Return numeric data with no NaN values for causal-learn PC."""
99
+ prepared = pd.DataFrame(index=df.index)
100
+ for column in df.columns:
101
+ numeric = pd.to_numeric(df[column], errors="coerce")
102
+ if numeric.notna().sum() >= max(2, int(0.5 * len(df))):
103
+ fill = float(numeric.median()) if numeric.notna().any() else 0.0
104
+ prepared[str(column)] = numeric.fillna(fill)
105
+ else:
106
+ codes, _ = pd.factorize(df[column].astype("string").fillna("<missing>"), sort=True)
107
+ prepared[str(column)] = codes.astype(float)
108
+ return prepared.fillna(0.0)
109
+
110
+
111
+ def _run_causal_learn_pc(
112
+ data: np.ndarray[Any, Any], columns: list[str], alpha: float
113
+ ) -> tuple[list[tuple[str, str]], str | None]:
114
+ """Run causal-learn PC and return deterministic directed edges."""
115
+ try:
116
+ from causallearn.search.ConstraintBased.PC import pc # type: ignore[import-untyped]
117
+
118
+ result = pc(data, alpha=alpha, indep_test="fisherz", stable=True, show_progress=False)
119
+ except Exception as exc:
120
+ return [], f"causal-learn PC unavailable or failed: {exc}"
121
+
122
+ matrix = getattr(getattr(result, "G", None), "graph", None)
123
+ if matrix is None:
124
+ return [], "causal-learn PC returned no adjacency matrix"
125
+
126
+ edges: list[tuple[str, str]] = []
127
+ arr = np.asarray(matrix)
128
+ for i, source in enumerate(columns):
129
+ for j, target in enumerate(columns):
130
+ if i >= j or i >= arr.shape[0] or j >= arr.shape[1]:
131
+ continue
132
+ if arr[i, j] != 0 or arr[j, i] != 0:
133
+ edges.append((source, target))
134
+ return edges, None
135
+
136
+
137
+ def _pairwise_dependency_edges(df: pd.DataFrame, alpha: float) -> list[tuple[str, str, float]]:
138
+ """Return deterministic low-confidence edges for dependent column pairs."""
139
+ columns = [str(column) for column in df.columns]
140
+ edges: list[tuple[str, str, float]] = []
141
+ for i, source in enumerate(columns):
142
+ for target in columns[i + 1 :]:
143
+ p_value = _pairwise_p_value(df[source], df[target])
144
+ if p_value < alpha:
145
+ confidence = max(0.25, min(0.75, 1.0 - p_value))
146
+ edges.append((source, target, round(confidence, 4)))
147
+ return edges
148
+
149
+
150
+ def _pairwise_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
151
+ """Return a p-value using categorical, continuous, or mixed tests."""
152
+ left_numeric = pd.to_numeric(left, errors="coerce")
153
+ right_numeric = pd.to_numeric(right, errors="coerce")
154
+ left_cont = left_numeric.notna().sum() >= max(5, int(0.8 * len(left)))
155
+ right_cont = right_numeric.notna().sum() >= max(5, int(0.8 * len(right)))
156
+
157
+ if left_cont and right_cont:
158
+ return _hsic_p_value(
159
+ left_numeric.fillna(left_numeric.median()), right_numeric.fillna(right_numeric.median())
160
+ )
161
+ if not left_cont and not right_cont:
162
+ return _chi_squared_p_value(left, right)
163
+ return _mutual_information_p_value(left, right)
164
+
165
+
166
+ def _chi_squared_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
167
+ """Return chi-squared independence p-value for categorical pairs."""
168
+ table = pd.crosstab(
169
+ left.astype("string").fillna("<missing>"), right.astype("string").fillna("<missing>")
170
+ )
171
+ if table.shape[0] < 2 or table.shape[1] < 2:
172
+ return 1.0
173
+ _, p_value, _, _ = chi2_contingency(table)
174
+ return float(p_value)
175
+
176
+
177
+ def _hsic_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
178
+ """Return HSIC p-value for continuous pairs, with correlation fallback."""
179
+ x = left.to_numpy(dtype=float).reshape(-1, 1)
180
+ y = right.to_numpy(dtype=float).reshape(-1, 1)
181
+ try:
182
+ from hyppo.independence import Hsic # type: ignore[import-untyped]
183
+
184
+ _, p_value = Hsic().test(x, y, reps=100, auto=True)
185
+ return float(p_value)
186
+ except Exception:
187
+ corr = abs(float(np.corrcoef(x[:, 0], y[:, 0])[0, 1]))
188
+ return 0.0 if corr > 0.75 else 1.0
189
+
190
+
191
+ def _mutual_information_p_value(left: pd.Series[Any], right: pd.Series[Any]) -> float:
192
+ """Return a bounded pseudo p-value from binned mutual information."""
193
+ left_codes = _codes(left)
194
+ right_codes = _codes(right)
195
+ table = pd.crosstab(left_codes, right_codes)
196
+ total = float(table.to_numpy().sum())
197
+ if total == 0.0 or table.shape[0] < 2 or table.shape[1] < 2:
198
+ return 1.0
199
+ joint = table.to_numpy(dtype=float) / total
200
+ px = joint.sum(axis=1, keepdims=True)
201
+ py = joint.sum(axis=0, keepdims=True)
202
+ expected = px @ py
203
+ mask = joint > 0
204
+ mi = float((joint[mask] * np.log(joint[mask] / expected[mask])).sum())
205
+ return float(np.exp(-mi))
206
+
207
+
208
+ def _codes(series: pd.Series[Any]) -> np.ndarray[Any, Any]:
209
+ """Return stable integer codes for a mixed-type series."""
210
+ numeric = pd.to_numeric(series, errors="coerce")
211
+ if numeric.notna().sum() >= max(5, int(0.8 * len(series))):
212
+ return pd.qcut(
213
+ numeric.fillna(numeric.median()), q=4, duplicates="drop"
214
+ ).cat.codes.to_numpy()
215
+ codes, _ = pd.factorize(series.astype("string").fillna("<missing>"), sort=True)
216
+ return codes
217
+
218
+
219
+ def _try_add_edge(
220
+ dag: CausalDAG,
221
+ source: str,
222
+ target: str,
223
+ *,
224
+ confidence: float,
225
+ provenance: str,
226
+ warnings: list[str],
227
+ ) -> None:
228
+ """Add an edge or record the cycle warning."""
229
+ try:
230
+ dag.add_edge(source, target, confidence=confidence, provenance=provenance)
231
+ except ValueError as exc:
232
+ warnings.append(str(exc))
dataforge/causal/root_cause.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal root-cause selection over detected errors and a causal DAG."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Protocol
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+ from dataforge.causal.dag import CausalDAG
10
+
11
+ __all__ = [
12
+ "CausalRootCauseAnalyzer",
13
+ "ErrorEvidence",
14
+ "RootCauseResult",
15
+ "evidence_from_issue",
16
+ "minimal_root_set",
17
+ ]
18
+
19
+
20
+ class _IssueLike(Protocol):
21
+ """Protocol for objects with row/column issue fields."""
22
+
23
+ row: int
24
+ column: str
25
+ issue_type: str
26
+
27
+
28
+ class ErrorEvidence(BaseModel):
29
+ """Column-mapped detected error used for causal root-cause analysis.
30
+
31
+ Args:
32
+ index: Zero-based error index in the caller's selected issue list.
33
+ row: Row index where the error was detected.
34
+ column: Column where the error was detected.
35
+ issue_type: Machine-readable issue type.
36
+ """
37
+
38
+ index: int = Field(ge=0)
39
+ row: int = Field(ge=0)
40
+ column: str = Field(min_length=1)
41
+ issue_type: str = Field(min_length=1)
42
+
43
+ model_config = {"frozen": True}
44
+
45
+
46
+ class RootCauseResult(BaseModel):
47
+ """Structured result returned by the root-cause analyzer.
48
+
49
+ Args:
50
+ root_indices: Minimal selected error indices.
51
+ root_columns: Root columns corresponding to root_indices.
52
+ covered_indices: Selected error indices covered by the root set.
53
+ confidence: Mean path confidence from roots to covered errors.
54
+ explanation: Human-readable explanation of the selected roots.
55
+ """
56
+
57
+ root_indices: list[int]
58
+ root_columns: list[str]
59
+ covered_indices: list[int]
60
+ confidence: float
61
+ explanation: str
62
+
63
+ model_config = {"frozen": True}
64
+
65
+
66
+ class CausalRootCauseAnalyzer:
67
+ """Compute minimal root causes for selected detected errors.
68
+
69
+ Args:
70
+ dag: Column-level causal DAG.
71
+
72
+ Example:
73
+ >>> dag = CausalDAG(["discount_pct", "order_total"])
74
+ >>> dag.add_edge("discount_pct", "order_total", confidence=0.9, provenance="formula")
75
+ >>> errors = [
76
+ ... ErrorEvidence(index=0, row=1, column="discount_pct", issue_type="bad"),
77
+ ... ErrorEvidence(index=1, row=1, column="order_total", issue_type="bad"),
78
+ ... ]
79
+ >>> CausalRootCauseAnalyzer(dag).analyze(errors).root_indices
80
+ [0]
81
+ """
82
+
83
+ def __init__(self, dag: CausalDAG) -> None:
84
+ self._dag = dag
85
+
86
+ def analyze(self, errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...]) -> RootCauseResult:
87
+ """Return the minimal root set for the selected errors.
88
+
89
+ Args:
90
+ errors: Selected detected errors.
91
+
92
+ Returns:
93
+ RootCauseResult with roots, coverage, confidence, and explanation.
94
+ """
95
+ if not errors:
96
+ return RootCauseResult(
97
+ root_indices=[],
98
+ root_columns=[],
99
+ covered_indices=[],
100
+ confidence=0.0,
101
+ explanation="No errors were supplied.",
102
+ )
103
+
104
+ roots: list[ErrorEvidence] = []
105
+ for candidate in errors:
106
+ if not self._has_upstream_selected_error(candidate, errors):
107
+ roots.append(candidate)
108
+
109
+ covered: list[int] = []
110
+ path_confidences: list[float] = []
111
+ for error in errors:
112
+ for root in roots:
113
+ if root.column == error.column or self._dag.is_reachable(root.column, error.column):
114
+ covered.append(error.index)
115
+ path_confidences.append(self._dag.path_confidence(root.column, error.column))
116
+ break
117
+
118
+ confidence = (
119
+ round(sum(path_confidences) / len(path_confidences), 4) if path_confidences else 0.0
120
+ )
121
+ root_columns = [root.column for root in roots]
122
+ return RootCauseResult(
123
+ root_indices=[root.index for root in roots],
124
+ root_columns=root_columns,
125
+ covered_indices=covered,
126
+ confidence=confidence,
127
+ explanation=self._explain(root_columns, len(covered), len(errors)),
128
+ )
129
+
130
+ def _has_upstream_selected_error(
131
+ self,
132
+ candidate: ErrorEvidence,
133
+ errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...],
134
+ ) -> bool:
135
+ """Return whether another selected error causally precedes candidate."""
136
+ for other in errors:
137
+ if other.index == candidate.index:
138
+ continue
139
+ if other.column == candidate.column and other.index < candidate.index:
140
+ return True
141
+ if other.column != candidate.column and self._dag.is_reachable(
142
+ other.column, candidate.column
143
+ ):
144
+ return True
145
+ return False
146
+
147
+ @staticmethod
148
+ def _explain(root_columns: list[str], covered_count: int, total_count: int) -> str:
149
+ """Build a compact result explanation."""
150
+ if not root_columns:
151
+ return "No minimal roots were found."
152
+ joined = ", ".join(root_columns)
153
+ return f"Selected {joined} as minimal roots covering {covered_count}/{total_count} errors."
154
+
155
+
156
+ def minimal_root_set(
157
+ errors: list[ErrorEvidence] | tuple[ErrorEvidence, ...], dag: CausalDAG
158
+ ) -> RootCauseResult:
159
+ """Convenience wrapper for CausalRootCauseAnalyzer.
160
+
161
+ Args:
162
+ errors: Selected detected errors.
163
+ dag: Column-level causal DAG.
164
+
165
+ Returns:
166
+ Minimal root-cause result.
167
+ """
168
+ return CausalRootCauseAnalyzer(dag).analyze(errors)
169
+
170
+
171
+ def evidence_from_issue(index: int, issue: _IssueLike | dict[str, Any]) -> ErrorEvidence:
172
+ """Build ErrorEvidence from an Issue-like object or dictionary.
173
+
174
+ Args:
175
+ index: Error index to assign.
176
+ issue: Object or dictionary with row/column/type fields.
177
+
178
+ Returns:
179
+ ErrorEvidence instance.
180
+ """
181
+ if isinstance(issue, dict):
182
+ return ErrorEvidence(
183
+ index=index,
184
+ row=int(issue.get("row", 0)),
185
+ column=str(issue.get("column", "")),
186
+ issue_type=str(issue.get("type", issue.get("issue_type", "unknown"))),
187
+ )
188
+ return ErrorEvidence(
189
+ index=index,
190
+ row=int(issue.row),
191
+ column=str(issue.column),
192
+ issue_type=str(issue.issue_type),
193
+ )
dataforge/cli/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- """Typer application entrypoint for DataForge.
2
 
3
  Each CLI subcommand is defined in its own module under ``dataforge.cli.*``
4
  and registered here. The ``app`` object is the entry point referenced by
@@ -7,13 +7,16 @@ and registered here. The ``app`` object is the entry point referenced by
7
 
8
  import typer
9
 
 
10
  from dataforge.cli.bench import bench
11
  from dataforge.cli.profile import profile
12
  from dataforge.cli.repair import repair
 
13
  from dataforge.cli.revert import revert
 
14
 
15
  app: typer.Typer = typer.Typer(
16
- help="DataForge AI-powered data-quality detection and repair.",
17
  no_args_is_help=True,
18
  )
19
 
@@ -28,15 +31,18 @@ def _main(
28
  is_eager=True,
29
  ),
30
  ) -> None:
31
- """DataForge AI-powered data-quality detection and repair."""
32
  if version:
33
  from dataforge import __version__
34
 
35
- typer.echo(f"dataforge {__version__}")
36
  raise typer.Exit()
37
 
38
 
39
  app.command(name="profile")(profile)
40
  app.command(name="repair")(repair)
41
  app.command(name="revert")(revert)
 
42
  app.command(name="bench")(bench)
 
 
 
1
+ """Typer application entrypoint for DataForge15.
2
 
3
  Each CLI subcommand is defined in its own module under ``dataforge.cli.*``
4
  and registered here. The ``app`` object is the entry point referenced by
 
7
 
8
  import typer
9
 
10
+ from dataforge.cli.audit import audit
11
  from dataforge.cli.bench import bench
12
  from dataforge.cli.profile import profile
13
  from dataforge.cli.repair import repair
14
+ from dataforge.cli.release import release_app
15
  from dataforge.cli.revert import revert
16
+ from dataforge.cli.watch import watch
17
 
18
  app: typer.Typer = typer.Typer(
19
+ help="DataForge15 - AI-powered data-quality detection and repair.",
20
  no_args_is_help=True,
21
  )
22
 
 
31
  is_eager=True,
32
  ),
33
  ) -> None:
34
+ """DataForge15 - AI-powered data-quality detection and repair."""
35
  if version:
36
  from dataforge import __version__
37
 
38
+ typer.echo(f"dataforge15 {__version__}")
39
  raise typer.Exit()
40
 
41
 
42
  app.command(name="profile")(profile)
43
  app.command(name="repair")(repair)
44
  app.command(name="revert")(revert)
45
+ app.command(name="audit")(audit)
46
  app.command(name="bench")(bench)
47
+ app.command(name="watch")(watch)
48
+ app.add_typer(release_app, name="release")
dataforge/cli/audit.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI subcommand: ``dataforge audit <txn_id>``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Annotated
8
+
9
+ import typer
10
+ from rich.console import Console
11
+ from rich.panel import Panel
12
+
13
+ from dataforge.transactions import TransactionAuditVerdict, verify_transaction_log
14
+
15
+ _console = Console(stderr=True)
16
+
17
+
18
+ def audit(
19
+ txn_id: Annotated[
20
+ str,
21
+ typer.Argument(help="Transaction identifier to audit."),
22
+ ],
23
+ search_root: Annotated[
24
+ Path | None,
25
+ typer.Option(
26
+ "--search-root",
27
+ help="Root directory used to locate the transaction log.",
28
+ exists=True,
29
+ file_okay=False,
30
+ dir_okay=True,
31
+ readable=True,
32
+ ),
33
+ ] = None,
34
+ log_path: Annotated[
35
+ Path | None,
36
+ typer.Option(
37
+ "--log-path",
38
+ help="Explicit JSONL transaction log path.",
39
+ exists=True,
40
+ file_okay=True,
41
+ dir_okay=False,
42
+ readable=True,
43
+ ),
44
+ ] = None,
45
+ json_output: Annotated[
46
+ bool,
47
+ typer.Option("--json", help="Print the audit report as JSON."),
48
+ ] = False,
49
+ ) -> None:
50
+ """Verify a transaction log's local hash chain."""
51
+ report = verify_transaction_log(txn_id, log_path=log_path, search_root=search_root)
52
+ if json_output:
53
+ typer.echo(json.dumps(report.model_dump(mode="json"), indent=2, sort_keys=True))
54
+ else:
55
+ style = "green" if report.verdict == TransactionAuditVerdict.VERIFIED else "red"
56
+ body = (
57
+ f"Verdict: [bold]{report.verdict.value}[/bold]\n"
58
+ f"Transaction: {report.txn_id or txn_id}\n"
59
+ f"Events: {report.event_count}\n"
60
+ f"Head SHA-256: {report.head_sha256 or 'n/a'}"
61
+ )
62
+ if report.errors:
63
+ body += "\n\n" + "\n".join(f"- {error}" for error in report.errors)
64
+ _console.print(Panel(body, title="Transaction Audit", style=style))
65
+
66
+ if report.verdict == TransactionAuditVerdict.VERIFIED:
67
+ raise typer.Exit(code=0)
68
+ if report.verdict == TransactionAuditVerdict.LEGACY_UNVERIFIED:
69
+ raise typer.Exit(code=1)
70
+ raise typer.Exit(code=2)
dataforge/cli/bench.py CHANGED
@@ -2,17 +2,18 @@
2
 
3
  from __future__ import annotations
4
 
 
 
5
  from pathlib import Path
6
- from typing import Annotated
7
 
8
  import typer
9
  from rich.console import Console
10
  from rich.panel import Panel
11
  from rich.table import Table
12
 
13
- from dataforge.bench.runner import run_agent_comparison
14
-
15
  _console = Console(stderr=True)
 
16
 
17
 
18
  def _parse_csv_list(raw_value: str) -> list[str]:
@@ -21,6 +22,16 @@ def _parse_csv_list(raw_value: str) -> list[str]:
21
  return [value for value in values if value]
22
 
23
 
 
 
 
 
 
 
 
 
 
 
24
  def bench(
25
  methods: Annotated[
26
  str,
@@ -54,10 +65,14 @@ def bench(
54
  help="Where to write eval/results/agent_comparison.json.",
55
  ),
56
  ] = Path("eval/results/agent_comparison.json"),
 
 
 
 
57
  ) -> None:
58
  """Run real-world benchmark methods across cached benchmark datasets."""
59
  try:
60
- output = run_agent_comparison(
61
  methods=_parse_csv_list(methods),
62
  datasets=_parse_csv_list(datasets),
63
  seeds=seeds,
@@ -74,6 +89,10 @@ def bench(
74
  )
75
  raise typer.Exit(code=2) from exc
76
 
 
 
 
 
77
  table = Table(title="DataForge Benchmark Summary")
78
  table.add_column("Method")
79
  table.add_column("Dataset")
 
2
 
3
  from __future__ import annotations
4
 
5
+ import json
6
+ from collections.abc import Callable
7
  from pathlib import Path
8
+ from typing import Annotated, Any
9
 
10
  import typer
11
  from rich.console import Console
12
  from rich.panel import Panel
13
  from rich.table import Table
14
 
 
 
15
  _console = Console(stderr=True)
16
+ run_agent_comparison: Callable[..., Any] | None = None
17
 
18
 
19
  def _parse_csv_list(raw_value: str) -> list[str]:
 
22
  return [value for value in values if value]
23
 
24
 
25
+ def _runner() -> Callable[..., Any]:
26
+ """Load the benchmark runner lazily so core CLI imports stay lightweight."""
27
+ global run_agent_comparison
28
+ if run_agent_comparison is None:
29
+ from dataforge.bench.runner import run_agent_comparison as loaded_runner
30
+
31
+ run_agent_comparison = loaded_runner
32
+ return run_agent_comparison
33
+
34
+
35
  def bench(
36
  methods: Annotated[
37
  str,
 
65
  help="Where to write eval/results/agent_comparison.json.",
66
  ),
67
  ] = Path("eval/results/agent_comparison.json"),
68
+ json_output: Annotated[
69
+ bool,
70
+ typer.Option("--json", help="Print benchmark results as JSON."),
71
+ ] = False,
72
  ) -> None:
73
  """Run real-world benchmark methods across cached benchmark datasets."""
74
  try:
75
+ output = _runner()(
76
  methods=_parse_csv_list(methods),
77
  datasets=_parse_csv_list(datasets),
78
  seeds=seeds,
 
89
  )
90
  raise typer.Exit(code=2) from exc
91
 
92
+ if json_output:
93
+ typer.echo(json.dumps(output.model_dump(mode="json"), indent=2, sort_keys=True))
94
+ return
95
+
96
  table = Table(title="DataForge Benchmark Summary")
97
  table.add_column("Method")
98
  table.add_column("Dataset")
dataforge/cli/common.py CHANGED
@@ -3,13 +3,14 @@
3
  from __future__ import annotations
4
 
5
  from collections.abc import Iterable
 
6
  from pathlib import Path
7
  from typing import cast
8
 
9
- import pandas as pd
10
  import typer
11
  import yaml
12
 
 
13
  from dataforge.verifier.schema import (
14
  AggregateDependency,
15
  AggregateLiteral,
@@ -18,6 +19,27 @@ from dataforge.verifier.schema import (
18
  Schema,
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def schema_from_mapping(raw_mapping: object) -> Schema:
23
  """Build a Schema from a raw YAML mapping-like payload.
@@ -149,13 +171,13 @@ def load_schema(schema_path: Path) -> Schema:
149
  return schema_from_mapping(raw)
150
 
151
 
152
- def read_csv(path: Path) -> pd.DataFrame:
153
  """Read a CSV using conservative string-preserving defaults.
154
 
155
  Args:
156
  path: CSV path.
157
 
158
  Returns:
159
- A DataFrame with string-preserved values.
160
  """
161
- return pd.read_csv(path, dtype=str, keep_default_na=False, na_filter=False)
 
3
  from __future__ import annotations
4
 
5
  from collections.abc import Iterable
6
+ from importlib import resources
7
  from pathlib import Path
8
  from typing import cast
9
 
 
10
  import typer
11
  import yaml
12
 
13
+ from dataforge.table import Table, read_csv as read_table_csv
14
  from dataforge.verifier.schema import (
15
  AggregateDependency,
16
  AggregateLiteral,
 
19
  Schema,
20
  )
21
 
22
+ _PACKAGED_DEMO_FIXTURES = {
23
+ "fixtures/hospital_10rows.csv": "fixtures/hospital_10rows.csv",
24
+ "fixtures/hospital_schema.yaml": "fixtures/hospital_schema.yaml",
25
+ }
26
+
27
+
28
+ def resolve_cli_path(path: Path) -> Path:
29
+ """Resolve a user path, including DataForge's packaged demo fixture aliases."""
30
+ if path.exists():
31
+ return path
32
+
33
+ normalized = path.as_posix().replace("\\", "/").lstrip("./")
34
+ packaged_name = _PACKAGED_DEMO_FIXTURES.get(normalized)
35
+ if packaged_name is None:
36
+ return path
37
+
38
+ fixture = resources.files("dataforge").joinpath(packaged_name)
39
+ if not fixture.is_file():
40
+ return path
41
+ return Path(str(fixture))
42
+
43
 
44
  def schema_from_mapping(raw_mapping: object) -> Schema:
45
  """Build a Schema from a raw YAML mapping-like payload.
 
171
  return schema_from_mapping(raw)
172
 
173
 
174
+ def read_csv(path: Path) -> Table:
175
  """Read a CSV using conservative string-preserving defaults.
176
 
177
  Args:
178
  path: CSV path.
179
 
180
  Returns:
181
+ A string-preserving DataForge table.
182
  """
183
+ return read_table_csv(path)
dataforge/cli/profile.py CHANGED
@@ -1,31 +1,46 @@
1
  """CLI subcommand: ``dataforge profile <path> [--schema <yaml>]``.
2
 
3
  Reads a CSV file, runs all detectors, and renders detected issues as a
4
- rich-formatted terminal table. Exit code 0 if no UNSAFE issues; 1 otherwise.
 
5
  """
6
 
7
  from __future__ import annotations
8
 
 
 
9
  from pathlib import Path
10
- from typing import Annotated
11
 
12
  import typer
13
  from rich.console import Console
14
 
15
- from dataforge.cli.common import load_schema, read_csv
16
  from dataforge.detectors import run_all_detectors
17
- from dataforge.detectors.base import Schema, Severity
18
  from dataforge.ui.profile_view import render_profile_table
19
 
20
  _console = Console(stderr=True)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def profile(
24
  path: Annotated[
25
  Path,
26
  typer.Argument(
27
- exists=True,
28
- readable=True,
29
  help="Path to the CSV file to profile.",
30
  ),
31
  ],
@@ -33,22 +48,36 @@ def profile(
33
  Path | None,
34
  typer.Option(
35
  "--schema",
36
- exists=True,
37
- readable=True,
38
  help="Path to a YAML schema file with column types and FDs.",
39
  ),
40
  ] = None,
 
 
 
 
 
 
 
 
 
 
 
41
  ) -> None:
42
  """Profile a CSV file for data-quality issues.
43
 
44
  Reads the CSV, runs all detectors (type_mismatch, decimal_shift,
45
  fd_violation), and renders a rich-formatted table of detected issues.
46
 
47
- Exit code 0 if no UNSAFE issues are found; 1 if any UNSAFE issues exist.
48
  """
 
 
 
 
 
49
  # Load the CSV with dtype=str to avoid pandas type-coercion artifacts.
50
  try:
51
- df = read_csv(path)
52
  except Exception as exc:
53
  _console.print(f"[bold red]Error reading CSV:[/bold red] {exc}")
54
  raise typer.Exit(code=2) from exc
@@ -56,16 +85,32 @@ def profile(
56
  # Optionally load schema.
57
  parsed_schema: Schema | None = None
58
  if schema is not None:
59
- parsed_schema = load_schema(schema)
 
 
 
 
60
 
61
  # Run all detectors.
62
  issues = run_all_detectors(df, parsed_schema)
63
 
64
  # Render the results.
65
- output_console = Console()
66
- render_profile_table(issues, output_console, file_path=str(path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # Exit code based on UNSAFE issues.
69
- has_unsafe = any(i.severity == Severity.UNSAFE for i in issues)
70
- if has_unsafe:
71
  raise typer.Exit(code=1)
 
1
  """CLI subcommand: ``dataforge profile <path> [--schema <yaml>]``.
2
 
3
  Reads a CSV file, runs all detectors, and renders detected issues as a
4
+ rich-formatted terminal table. Diagnostics exit 0 by default; use
5
+ ``--fail-on`` for CI gating.
6
  """
7
 
8
  from __future__ import annotations
9
 
10
+ import json
11
+ from collections.abc import Sequence
12
  from pathlib import Path
13
+ from typing import Annotated, Literal
14
 
15
  import typer
16
  from rich.console import Console
17
 
18
+ from dataforge.cli.common import load_schema, read_csv, resolve_cli_path
19
  from dataforge.detectors import run_all_detectors
20
+ from dataforge.detectors.base import Issue, Schema, Severity
21
  from dataforge.ui.profile_view import render_profile_table
22
 
23
  _console = Console(stderr=True)
24
 
25
+ FailOn = Literal["never", "unsafe", "review", "any"]
26
+
27
+
28
+ def _should_fail(issues: Sequence[Issue], fail_on: FailOn) -> bool:
29
+ """Return whether profile findings should trip the requested CI gate."""
30
+ if fail_on == "never":
31
+ return False
32
+ if fail_on == "any":
33
+ return bool(issues)
34
+ severities = [issue.severity for issue in issues]
35
+ if fail_on == "unsafe":
36
+ return any(severity == Severity.UNSAFE for severity in severities)
37
+ return any(severity >= Severity.REVIEW for severity in severities)
38
+
39
 
40
  def profile(
41
  path: Annotated[
42
  Path,
43
  typer.Argument(
 
 
44
  help="Path to the CSV file to profile.",
45
  ),
46
  ],
 
48
  Path | None,
49
  typer.Option(
50
  "--schema",
 
 
51
  help="Path to a YAML schema file with column types and FDs.",
52
  ),
53
  ] = None,
54
+ json_output: Annotated[
55
+ bool,
56
+ typer.Option("--json", help="Print profile results as JSON."),
57
+ ] = False,
58
+ fail_on: Annotated[
59
+ FailOn,
60
+ typer.Option(
61
+ "--fail-on",
62
+ help="Exit 1 when findings meet this threshold: never, unsafe, review, any.",
63
+ ),
64
+ ] = "never",
65
  ) -> None:
66
  """Profile a CSV file for data-quality issues.
67
 
68
  Reads the CSV, runs all detectors (type_mismatch, decimal_shift,
69
  fd_violation), and renders a rich-formatted table of detected issues.
70
 
71
+ Exit code 0 unless ``--fail-on`` is set and matching findings are present.
72
  """
73
+ resolved_path = resolve_cli_path(path)
74
+ if not resolved_path.exists():
75
+ _console.print(f"[bold red]CSV file not found:[/bold red] {path}")
76
+ raise typer.Exit(code=2)
77
+
78
  # Load the CSV with dtype=str to avoid pandas type-coercion artifacts.
79
  try:
80
+ df = read_csv(resolved_path)
81
  except Exception as exc:
82
  _console.print(f"[bold red]Error reading CSV:[/bold red] {exc}")
83
  raise typer.Exit(code=2) from exc
 
85
  # Optionally load schema.
86
  parsed_schema: Schema | None = None
87
  if schema is not None:
88
+ resolved_schema = resolve_cli_path(schema)
89
+ if not resolved_schema.exists():
90
+ _console.print(f"[bold red]Schema file not found:[/bold red] {schema}")
91
+ raise typer.Exit(code=2)
92
+ parsed_schema = load_schema(resolved_schema)
93
 
94
  # Run all detectors.
95
  issues = run_all_detectors(df, parsed_schema)
96
 
97
  # Render the results.
98
+ if json_output:
99
+ typer.echo(
100
+ json.dumps(
101
+ {
102
+ "path": str(resolved_path),
103
+ "issues_count": len(issues),
104
+ "fail_on": fail_on,
105
+ "issues": [issue.model_dump(mode="json") for issue in issues],
106
+ },
107
+ indent=2,
108
+ sort_keys=True,
109
+ )
110
+ )
111
+ else:
112
+ output_console = Console()
113
+ render_profile_table(issues, output_console, file_path=str(resolved_path))
114
 
115
+ if _should_fail(issues, fail_on):
 
 
116
  raise typer.Exit(code=1)
dataforge/cli/release.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI group for local release verification."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Annotated
8
+
9
+ import typer
10
+
11
+ from dataforge.release.doctor import DEFAULT_KAGGLE_CREDENTIALS, run_doctor
12
+
13
+ release_app = typer.Typer(help="Release verification utilities.", no_args_is_help=True)
14
+
15
+
16
+ @release_app.command(name="doctor")
17
+ def doctor(
18
+ json_output: Annotated[
19
+ bool,
20
+ typer.Option("--json", help="Print machine-readable JSON."),
21
+ ] = False,
22
+ kaggle_credentials: Annotated[
23
+ Path,
24
+ typer.Option(
25
+ "--kaggle-credentials",
26
+ help="Path to Kaggle OAuth credentials.json. Legacy kaggle.json is never read.",
27
+ ),
28
+ ] = DEFAULT_KAGGLE_CREDENTIALS,
29
+ ) -> None:
30
+ """Verify local release/deploy auth without printing secrets."""
31
+ report = run_doctor(kaggle_credentials=kaggle_credentials)
32
+ if json_output:
33
+ typer.echo(json.dumps(report.to_dict(), indent=2, sort_keys=True))
34
+ else:
35
+ for check in report.checks:
36
+ status = "ok" if check.ok else "fail"
37
+ typer.echo(f"{status:4} {check.name}: {check.detail}")
38
+ raise typer.Exit(code=0 if report.ok else 2)
39
+
dataforge/cli/repair.py CHANGED
@@ -2,32 +2,25 @@
2
 
3
  from __future__ import annotations
4
 
5
- import hashlib
6
- from datetime import UTC, datetime
7
  from pathlib import Path
8
- from typing import Annotated
9
 
10
- import pandas as pd
11
  import typer
12
  from rich.console import Console
13
  from rich.panel import Panel
14
 
15
- from dataforge.cli.common import load_schema, read_csv
16
- from dataforge.detectors import run_all_detectors
17
  from dataforge.detectors.base import Issue, Schema
18
- from dataforge.repairers import build_repairers
19
- from dataforge.repairers.base import ProposedFix, RepairAttempt, RetryContext
20
- from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict
21
- from dataforge.transactions.log import (
22
- append_applied_event,
23
- append_created_transaction,
24
- cache_dir_for,
25
- sha256_bytes,
26
- snapshot_path_for,
27
- )
28
- from dataforge.transactions.txn import CellFix, RepairTransaction, generate_txn_id
29
  from dataforge.ui.repair_diff import render_repair_diff
30
- from dataforge.verifier import SMTVerifier, VerificationVerdict
 
 
 
 
31
 
32
  _console = Console(stderr=True)
33
 
@@ -45,32 +38,19 @@ def apply_fixes_to_csv(path: Path, fixes: list[CellFix]) -> str:
45
  Raises:
46
  ValueError: If a fix references a missing row/column or stale old value.
47
  """
48
- df = read_csv(path)
49
- for fix in fixes:
50
- if fix.operation != "update":
51
- raise ValueError(f"Unsupported repair operation '{fix.operation}' for row {fix.row}.")
52
- if fix.column not in df.columns:
53
- raise ValueError(f"Column '{fix.column}' not found in '{path}'.")
54
- if fix.row < 0 or fix.row >= len(df.index):
55
- raise ValueError(f"Row {fix.row} is out of bounds for '{path}'.")
56
-
57
- current_value = str(df.at[fix.row, fix.column])
58
- if current_value != fix.old_value:
59
- raise ValueError(
60
- f"Refusing to apply stale fix for row {fix.row}, column '{fix.column}': "
61
- f"expected '{fix.old_value}', found '{current_value}'."
62
- )
63
- df.at[fix.row, fix.column] = fix.new_value
64
 
65
- df.to_csv(path, index=False, lineterminator="\n")
66
- return hashlib.sha256(path.read_bytes()).hexdigest()
67
 
68
 
69
  def _resolve_schema(schema_path: Path | None) -> Schema | None:
70
  """Resolve an optional schema path into a parsed Schema."""
71
  if schema_path is None:
72
  return None
73
- return load_schema(schema_path)
 
 
 
74
 
75
 
76
  def _print_error(message: str, *, hint: str | None = None) -> None:
@@ -94,157 +74,21 @@ def _propose_repairs(
94
  confirm_escalations: bool,
95
  interactive: bool,
96
  ) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]:
97
- """Run repairers and gates issue-by-issue against the working dataframe."""
98
- repairers = build_repairers(
99
- cache_dir=cache_dir_for(path),
 
 
 
 
 
100
  allow_llm=allow_llm,
101
  model=model,
102
- )
103
- safety_filter = SafetyFilter()
104
- verifier = SMTVerifier()
105
- safety_context = SafetyContext(
106
  allow_pii=allow_pii,
107
  confirm_pii=confirm_pii,
108
  confirm_escalations=confirm_escalations,
109
- )
110
-
111
- accepted_fixes: list[ProposedFix] = []
112
- attempt_groups: list[list[RepairAttempt]] = []
113
-
114
- for issue in issues:
115
- attempts: list[RepairAttempt] = []
116
- repairer = repairers.get(issue.issue_type)
117
- if repairer is None:
118
- attempts.append(
119
- RepairAttempt(
120
- issue=issue,
121
- attempt_number=1,
122
- status="attempted_not_fixed",
123
- reason="No repairer is registered for this issue type.",
124
- )
125
- )
126
- attempt_groups.append(attempts)
127
- continue
128
-
129
- accepted = False
130
- retry_context = RetryContext(issue=issue)
131
- for attempt_number in range(1, 4):
132
- candidate = repairer.propose(issue, working_df, schema, retry_context=retry_context)
133
- if candidate is None:
134
- attempts.append(
135
- RepairAttempt(
136
- issue=issue,
137
- attempt_number=attempt_number,
138
- status="attempted_not_fixed",
139
- reason="No repair proposal was available for this issue.",
140
- )
141
- )
142
- break
143
-
144
- preferred = safety_filter.choose_preferred([candidate], schema, safety_context)
145
- safety_result = safety_filter.evaluate(preferred, schema, safety_context)
146
- if safety_result.verdict == SafetyVerdict.ESCALATE and interactive:
147
- safety_context, safety_result = _resolve_escalation(
148
- preferred,
149
- schema,
150
- safety_context,
151
- safety_filter,
152
- safety_result,
153
- )
154
-
155
- if safety_result.verdict == SafetyVerdict.DENY:
156
- attempts.append(
157
- RepairAttempt(
158
- issue=issue,
159
- attempt_number=attempt_number,
160
- fix=preferred,
161
- status="denied",
162
- reason=safety_result.reason,
163
- )
164
- )
165
- retry_context = _build_retry_context(issue, attempts)
166
- continue
167
-
168
- if safety_result.verdict == SafetyVerdict.ESCALATE:
169
- attempts.append(
170
- RepairAttempt(
171
- issue=issue,
172
- attempt_number=attempt_number,
173
- fix=preferred,
174
- status="escalated",
175
- reason=safety_result.reason,
176
- )
177
- )
178
- break
179
-
180
- verifier_result = verifier.verify(working_df, [preferred], schema)
181
- if verifier_result.verdict == VerificationVerdict.ACCEPT:
182
- accepted_fixes.append(preferred)
183
- working_df.at[preferred.fix.row, preferred.fix.column] = preferred.fix.new_value
184
- attempts.append(
185
- RepairAttempt(
186
- issue=issue,
187
- attempt_number=attempt_number,
188
- fix=preferred,
189
- status="accepted",
190
- reason=verifier_result.reason,
191
- )
192
- )
193
- accepted = True
194
- break
195
-
196
- attempts.append(
197
- RepairAttempt(
198
- issue=issue,
199
- attempt_number=attempt_number,
200
- fix=preferred,
201
- status=(
202
- "rejected"
203
- if verifier_result.verdict == VerificationVerdict.REJECT
204
- else "unknown"
205
- ),
206
- reason=verifier_result.reason,
207
- unsat_core=verifier_result.unsat_core,
208
- )
209
- )
210
- retry_context = _build_retry_context(issue, attempts)
211
-
212
- if (
213
- not accepted
214
- and attempts
215
- and attempts[-1].status not in {"attempted_not_fixed", "escalated"}
216
- ):
217
- last_reason = attempts[-1].reason
218
- attempts[-1] = attempts[-1].model_copy(
219
- update={
220
- "status": "attempted_not_fixed",
221
- "reason": (
222
- f"Issue was attempted but not fixed after {len(attempts)} attempt(s). "
223
- f"Last failure: {last_reason}"
224
- ),
225
- }
226
- )
227
- attempt_groups.append(attempts)
228
-
229
- return accepted_fixes, attempt_groups
230
-
231
-
232
- def _build_retry_context(issue: Issue, attempts: list[RepairAttempt]) -> RetryContext:
233
- """Build retry hints from previous failed attempts."""
234
- rejected_values = frozenset(
235
- attempt.fix.fix.new_value
236
- for attempt in attempts
237
- if attempt.fix is not None and attempt.status in {"denied", "rejected", "unknown"}
238
- )
239
- hints: list[str] = []
240
- for attempt in attempts:
241
- hints.append(attempt.reason)
242
- hints.extend(attempt.unsat_core)
243
- return RetryContext(
244
- issue=issue,
245
- previous_attempts=tuple(attempts),
246
- rejected_values=rejected_values,
247
- hints=tuple(hints),
248
  )
249
 
250
 
@@ -309,45 +153,46 @@ def _render_attempt_summary(
309
  return len(failed_groups)
310
 
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  def _apply_transaction(
313
  path: Path,
314
  fixes: list[ProposedFix],
315
  source_bytes: bytes,
316
  ) -> str:
317
- """Write a transaction record, apply fixes, and append the applied event."""
318
- resolved_path = path.resolve()
319
- txn_id = generate_txn_id()
320
- snapshot_path = snapshot_path_for(resolved_path, txn_id)
321
- snapshot_path.parent.mkdir(parents=True, exist_ok=True)
322
- snapshot_path.write_bytes(source_bytes)
323
-
324
- transaction = RepairTransaction(
325
- txn_id=txn_id,
326
- created_at=datetime.now(UTC),
327
- source_path=str(resolved_path),
328
- source_sha256=sha256_bytes(source_bytes),
329
- source_snapshot_path=str(snapshot_path.resolve()),
330
- fixes=[proposal.fix for proposal in fixes],
331
- applied=False,
332
- )
333
- log_path = append_created_transaction(transaction)
334
 
335
- try:
336
- post_sha256 = apply_fixes_to_csv(path, [proposal.fix for proposal in fixes])
337
- append_applied_event(log_path, txn_id, post_sha256=post_sha256)
338
- except Exception:
339
- path.write_bytes(source_bytes)
340
- raise
341
-
342
- return txn_id
343
 
344
 
345
  def repair(
346
  path: Annotated[
347
  Path,
348
  typer.Argument(
349
- exists=True,
350
- readable=True,
351
  help="Path to the CSV file to repair.",
352
  ),
353
  ],
@@ -355,8 +200,6 @@ def repair(
355
  Path | None,
356
  typer.Option(
357
  "--schema",
358
- exists=True,
359
- readable=True,
360
  help="Path to a YAML schema file with column types and FDs.",
361
  ),
362
  ] = None,
@@ -400,6 +243,10 @@ def repair(
400
  str,
401
  typer.Option("--llm-model", help="Model name for fd_violation LLM fallback."),
402
  ] = "gemini-2.0-flash",
 
 
 
 
403
  ) -> None:
404
  """Detect, propose, and optionally apply reversible repairs to a CSV."""
405
  if dry_run == apply:
@@ -410,58 +257,66 @@ def repair(
410
  raise typer.Exit(code=2)
411
 
412
  try:
 
 
 
413
  parsed_schema = _resolve_schema(schema)
414
- df = read_csv(path)
415
  except Exception as exc:
416
  _print_error(str(exc))
417
  raise typer.Exit(code=2) from exc
418
 
419
- issues = run_all_detectors(df, parsed_schema)
420
- accepted_fixes, attempt_groups = _propose_repairs(
421
- issues,
422
- path,
423
- df.copy(deep=True),
424
- parsed_schema,
425
- allow_llm=allow_llm,
426
- model=llm_model,
427
- allow_pii=allow_pii,
428
- confirm_pii=confirm_pii,
429
- confirm_escalations=confirm_escalations,
430
- interactive=apply,
431
- )
 
 
 
 
 
 
 
 
 
432
 
433
- output_console = Console()
434
- render_repair_diff(accepted_fixes, output_console, file_path=str(path))
435
- failed_issue_count = _render_attempt_summary(attempt_groups, output_console)
436
 
437
- if not accepted_fixes and failed_issue_count == 0:
 
 
 
 
 
 
 
 
 
 
 
 
438
  raise typer.Exit(code=1)
439
 
440
  if dry_run:
441
- raise typer.Exit(code=0 if accepted_fixes else 1)
442
-
443
- if not accepted_fixes:
444
- raise typer.Exit(code=1)
445
 
446
- batch_safety = SafetyFilter().evaluate_batch(accepted_fixes)
447
- if batch_safety.verdict != SafetyVerdict.ALLOW:
448
- _print_error(batch_safety.reason)
449
  raise typer.Exit(code=1)
450
 
451
- source_bytes = path.read_bytes()
452
- try:
453
- txn_id = _apply_transaction(path, accepted_fixes, source_bytes)
454
- except Exception as exc:
455
- _print_error(
456
- f"Failed to apply repairs: {exc}",
457
- hint="The source file was restored to its pre-apply bytes.",
458
- )
459
- raise typer.Exit(code=1) from exc
460
-
461
  output_console.print(
462
  Panel(
463
- f"[green]Applied {len(accepted_fixes)} fix(es).[/green]\n"
464
- f"Transaction ID: [bold]{txn_id}[/bold]",
465
  title="Repair Applied",
466
  style="green",
467
  )
 
2
 
3
  from __future__ import annotations
4
 
5
+ import json
 
6
  from pathlib import Path
7
+ from typing import TYPE_CHECKING, Annotated
8
 
 
9
  import typer
10
  from rich.console import Console
11
  from rich.panel import Panel
12
 
13
+ from dataforge.cli.common import load_schema, resolve_cli_path
 
14
  from dataforge.detectors.base import Issue, Schema
15
+ from dataforge.repairers.base import ProposedFix, RepairAttempt
16
+ from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult
17
+ from dataforge.transactions.txn import CellFix
 
 
 
 
 
 
 
 
18
  from dataforge.ui.repair_diff import render_repair_diff
19
+
20
+ if TYPE_CHECKING:
21
+ import pandas as pd
22
+
23
+ from dataforge.engine.repair import RepairPipelineResult
24
 
25
  _console = Console(stderr=True)
26
 
 
38
  Raises:
39
  ValueError: If a fix references a missing row/column or stale old value.
40
  """
41
+ from dataforge.engine.repair import apply_fixes_to_csv as engine_apply_fixes_to_csv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ return engine_apply_fixes_to_csv(path, fixes)
 
44
 
45
 
46
  def _resolve_schema(schema_path: Path | None) -> Schema | None:
47
  """Resolve an optional schema path into a parsed Schema."""
48
  if schema_path is None:
49
  return None
50
+ resolved_schema = resolve_cli_path(schema_path)
51
+ if not resolved_schema.exists():
52
+ raise typer.BadParameter(f"Schema file '{schema_path}' does not exist.")
53
+ return load_schema(resolved_schema)
54
 
55
 
56
  def _print_error(message: str, *, hint: str | None = None) -> None:
 
74
  confirm_escalations: bool,
75
  interactive: bool,
76
  ) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]:
77
+ """Compatibility wrapper around the shared repair engine proposal stage."""
78
+ from dataforge.engine.repair import propose_repairs as engine_propose_repairs
79
+
80
+ return engine_propose_repairs(
81
+ issues,
82
+ path,
83
+ working_df,
84
+ schema,
85
  allow_llm=allow_llm,
86
  model=model,
 
 
 
 
87
  allow_pii=allow_pii,
88
  confirm_pii=confirm_pii,
89
  confirm_escalations=confirm_escalations,
90
+ interactive=interactive,
91
+ escalation_resolver=_resolve_escalation,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  )
93
 
94
 
 
153
  return len(failed_groups)
154
 
155
 
156
+ def _render_failure_summary(result: RepairPipelineResult, console: Console) -> int:
157
+ """Render a summary for issues that the shared engine could not repair."""
158
+ if not result.failures:
159
+ return 0
160
+
161
+ console.print("[bold yellow]Attempted But Not Fixed[/bold yellow]")
162
+ for failure in result.failures:
163
+ prefix = ""
164
+ if any(label.startswith("fd::") for label in failure.unsat_core):
165
+ prefix = "functional dependency rejection - "
166
+ elif any(label.startswith("domain::") for label in failure.unsat_core):
167
+ prefix = "domain bound rejection - "
168
+ console.print(
169
+ f"{failure.issue_type} at {failure.row}:{failure.column} "
170
+ f"after {failure.attempt_count} attempt(s): {prefix}{failure.reason}",
171
+ overflow="fold",
172
+ )
173
+ return len(result.failures)
174
+
175
+
176
+ def _json_result(result: RepairPipelineResult) -> str:
177
+ """Serialize a repair result for CLI/MCP/CI consumers."""
178
+ return json.dumps(result.model_dump(mode="json"), indent=2, sort_keys=True)
179
+
180
+
181
  def _apply_transaction(
182
  path: Path,
183
  fixes: list[ProposedFix],
184
  source_bytes: bytes,
185
  ) -> str:
186
+ """Compatibility wrapper around the shared repair engine transaction path."""
187
+ from dataforge.engine.repair import apply_transaction as engine_apply_transaction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ return engine_apply_transaction(path, fixes, source_bytes)
 
 
 
 
 
 
 
190
 
191
 
192
  def repair(
193
  path: Annotated[
194
  Path,
195
  typer.Argument(
 
 
196
  help="Path to the CSV file to repair.",
197
  ),
198
  ],
 
200
  Path | None,
201
  typer.Option(
202
  "--schema",
 
 
203
  help="Path to a YAML schema file with column types and FDs.",
204
  ),
205
  ] = None,
 
243
  str,
244
  typer.Option("--llm-model", help="Model name for fd_violation LLM fallback."),
245
  ] = "gemini-2.0-flash",
246
+ json_output: Annotated[
247
+ bool,
248
+ typer.Option("--json", help="Print repair result as JSON."),
249
+ ] = False,
250
  ) -> None:
251
  """Detect, propose, and optionally apply reversible repairs to a CSV."""
252
  if dry_run == apply:
 
257
  raise typer.Exit(code=2)
258
 
259
  try:
260
+ resolved_path = resolve_cli_path(path)
261
+ if not resolved_path.exists():
262
+ raise typer.BadParameter(f"CSV file '{path}' does not exist.")
263
  parsed_schema = _resolve_schema(schema)
 
264
  except Exception as exc:
265
  _print_error(str(exc))
266
  raise typer.Exit(code=2) from exc
267
 
268
+ try:
269
+ from dataforge.engine.repair import RepairPipelineRequest, run_repair_pipeline
270
+
271
+ result = run_repair_pipeline(
272
+ RepairPipelineRequest(
273
+ source_path=resolved_path,
274
+ mode="apply" if apply else "dry_run",
275
+ schema=parsed_schema,
276
+ allow_llm=allow_llm,
277
+ model=llm_model,
278
+ allow_pii=allow_pii,
279
+ confirm_pii=confirm_pii,
280
+ confirm_escalations=confirm_escalations,
281
+ interactive=apply,
282
+ )
283
+ )
284
+ except Exception as exc:
285
+ _print_error(
286
+ f"Failed to apply repairs: {exc}" if apply else f"Failed to repair: {exc}",
287
+ hint="The source file was restored to its pre-apply bytes." if apply else None,
288
+ )
289
+ raise typer.Exit(code=1 if apply else 2) from exc
290
 
291
+ if json_output:
292
+ typer.echo(_json_result(result))
293
+ raise typer.Exit(code=0 if result.fixes else 1)
294
 
295
+ output_console = Console()
296
+ render_repair_diff(result.fixes, output_console, file_path=str(resolved_path))
297
+ failed_issue_count = _render_failure_summary(result, output_console)
298
+
299
+ if not result.fixes and failed_issue_count == 0:
300
+ if result.receipt.reason != "No accepted fixes were produced.":
301
+ output_console.print(
302
+ Panel(
303
+ f"[yellow]{result.receipt.reason}[/yellow]",
304
+ title="Repair Summary",
305
+ style="yellow",
306
+ )
307
+ )
308
  raise typer.Exit(code=1)
309
 
310
  if dry_run:
311
+ raise typer.Exit(code=0 if result.fixes else 1)
 
 
 
312
 
313
+ if not result.fixes or not result.receipt.applied:
 
 
314
  raise typer.Exit(code=1)
315
 
 
 
 
 
 
 
 
 
 
 
316
  output_console.print(
317
  Panel(
318
+ f"[green]Applied {len(result.fixes)} fix(es).[/green]\n"
319
+ f"Transaction ID: [bold]{result.receipt.txn_id}[/bold]",
320
  title="Repair Applied",
321
  style="green",
322
  )
dataforge/cli/watch.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI subcommand: ``dataforge watch``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Annotated, Literal
9
+
10
+ import typer
11
+ from rich.console import Console
12
+ from rich.panel import Panel
13
+
14
+ from dataforge.cli.common import load_schema, read_csv, resolve_cli_path
15
+ from dataforge.detectors import run_all_detectors
16
+ from dataforge.detectors.base import Schema
17
+ from dataforge.ui.profile_view import render_profile_table
18
+ from dataforge.ui.repair_diff import render_repair_diff
19
+
20
+ _console = Console(stderr=True)
21
+
22
+ WatchAction = Literal["profile", "repair"]
23
+
24
+
25
+ def _load_optional_schema(schema_path: Path | None) -> Schema | None:
26
+ if schema_path is None:
27
+ return None
28
+ resolved_schema = resolve_cli_path(schema_path)
29
+ if not resolved_schema.exists():
30
+ raise typer.BadParameter(f"Schema file '{schema_path}' does not exist.")
31
+ return load_schema(resolved_schema)
32
+
33
+
34
+ def _profile_once(path: Path, schema: Schema | None, json_output: bool) -> None:
35
+ df = read_csv(path)
36
+ issues = run_all_detectors(df, schema)
37
+ if json_output:
38
+ typer.echo(
39
+ json.dumps(
40
+ {
41
+ "event": "profile",
42
+ "path": str(path),
43
+ "issues_count": len(issues),
44
+ "issues": [issue.model_dump(mode="json") for issue in issues],
45
+ },
46
+ indent=2,
47
+ sort_keys=True,
48
+ )
49
+ )
50
+ return
51
+ render_profile_table(issues, Console(), file_path=str(path))
52
+
53
+
54
+ def _repair_once(path: Path, schema: Schema | None, apply: bool, json_output: bool) -> None:
55
+ from dataforge.engine.repair import RepairPipelineRequest, run_repair_pipeline
56
+
57
+ result = run_repair_pipeline(
58
+ RepairPipelineRequest(
59
+ source_path=path,
60
+ mode="apply" if apply else "dry_run",
61
+ schema=schema,
62
+ interactive=False,
63
+ )
64
+ )
65
+ if json_output:
66
+ payload = result.model_dump(mode="json")
67
+ payload["event"] = "repair"
68
+ typer.echo(json.dumps(payload, indent=2, sort_keys=True))
69
+ return
70
+ render_repair_diff(result.fixes, Console(), file_path=str(path))
71
+
72
+
73
+ def _run_once(path: Path, schema: Schema | None, action: WatchAction, apply: bool, json: bool) -> None:
74
+ if action == "repair":
75
+ _repair_once(path, schema, apply, json)
76
+ else:
77
+ _profile_once(path, schema, json)
78
+
79
+
80
+ def watch(
81
+ path: Annotated[
82
+ Path,
83
+ typer.Argument(help="CSV or dbt artifact path to watch."),
84
+ ],
85
+ schema: Annotated[
86
+ Path | None,
87
+ typer.Option("--schema", help="Path to a YAML schema file with column types and FDs."),
88
+ ] = None,
89
+ action: Annotated[
90
+ WatchAction,
91
+ typer.Option("--action", help="Action to run when the file changes: profile or repair."),
92
+ ] = "profile",
93
+ apply: Annotated[
94
+ bool,
95
+ typer.Option("--apply", help="Apply repairs on change. Defaults to dry-run repair."),
96
+ ] = False,
97
+ interval: Annotated[
98
+ float,
99
+ typer.Option("--interval", min=0.1, help="Polling interval in seconds."),
100
+ ] = 2.0,
101
+ once: Annotated[
102
+ bool,
103
+ typer.Option("--once", help="Run once and exit, useful for CI acceptance."),
104
+ ] = False,
105
+ json_output: Annotated[
106
+ bool,
107
+ typer.Option("--json", help="Print watch events as JSON."),
108
+ ] = False,
109
+ ) -> None:
110
+ """Poll a path and rerun profile or repair when it changes."""
111
+ resolved_path = resolve_cli_path(path)
112
+ if not resolved_path.exists():
113
+ _console.print(f"[bold red]Watch path not found:[/bold red] {path}")
114
+ raise typer.Exit(code=2)
115
+ parsed_schema = _load_optional_schema(schema)
116
+
117
+ if apply and action != "repair":
118
+ _console.print(
119
+ Panel(
120
+ "--apply is only valid with --action repair.",
121
+ title="Watch Error",
122
+ style="red",
123
+ )
124
+ )
125
+ raise typer.Exit(code=2)
126
+
127
+ _run_once(resolved_path, parsed_schema, action, apply, json_output)
128
+ if once:
129
+ return
130
+
131
+ last_mtime = resolved_path.stat().st_mtime_ns
132
+ while True:
133
+ time.sleep(interval)
134
+ try:
135
+ current_mtime = resolved_path.stat().st_mtime_ns
136
+ except FileNotFoundError:
137
+ _console.print(f"[bold red]Watch path disappeared:[/bold red] {resolved_path}")
138
+ raise typer.Exit(code=2) from None
139
+ if current_mtime == last_mtime:
140
+ continue
141
+ last_mtime = current_mtime
142
+ _run_once(resolved_path, parsed_schema, action, apply, json_output)
dataforge/datasets/embedded/hospital/clean.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ id,age,admission_date,name
2
+ 1,30,2020-01-01,Alice
3
+ 2,45,2020-01-02,Bob
4
+ 3,30,2020-01-03,Carol
5
+ 4,29,2020-01-04,Dave
6
+ 5,35,2020-01-05,Eve
7
+ 6,51,2020-01-06,Frank
8
+ 7,40,2020-01-07,Grace
9
+ 8,35,2020-01-08,Heidi
10
+ 9,28,2020-01-09,Ivan
11
+ 10,60,2020-01-10,Judy
dataforge/datasets/embedded/hospital/dirty.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ id,age,admission_date,name
2
+ 1,30,2020-01-01,Alice
3
+ 2,45,2020-01-02,Bob
4
+ 3,N/A,2020-01-03,Carol
5
+ 4,29,2020-01-04,Dave
6
+ 5,null,2020-01-05,Eve
7
+ 6,51,2020-01-06,Frank
8
+ 7,40,2020-01-07,Grace
9
+ 8,35,2020-01-08,Heidi
10
+ 9,28,2020-01-09,Ivan
11
+ 10,60,2020-01-10,Judy
dataforge/datasets/real_world.py CHANGED
@@ -2,6 +2,8 @@
2
 
3
  from __future__ import annotations
4
 
 
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
 
@@ -16,6 +18,9 @@ class DatasetDownloadError(RuntimeError):
16
  """Raised when a real-world dataset cannot be downloaded or loaded from cache."""
17
 
18
 
 
 
 
19
  class GroundTruthCell(BaseModel):
20
  """Single cell-level dirty-to-clean correction used for benchmark scoring."""
21
 
@@ -57,7 +62,11 @@ def _read_cached_csv(path: Path) -> pd.DataFrame:
57
 
58
  def _download_bytes(url: str) -> bytes:
59
  """Download raw CSV bytes from an upstream source URL."""
60
- with httpx.Client(timeout=60.0, follow_redirects=True) as client:
 
 
 
 
61
  response = client.get(url)
62
  response.raise_for_status()
63
  return response.content
@@ -67,8 +76,19 @@ def _download_to_cache(metadata: DatasetMetadata, dataset_dir: Path) -> None:
67
  """Download dirty/clean CSV files into the dataset cache directory."""
68
  dataset_dir.mkdir(parents=True, exist_ok=True)
69
  dirty_url, clean_url = metadata.source_urls
 
70
  (dataset_dir / "dirty.csv").write_bytes(_download_bytes(dirty_url))
71
  (dataset_dir / "clean.csv").write_bytes(_download_bytes(clean_url))
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  def _manual_download_message(metadata: DatasetMetadata, dataset_dir: Path, cause: Exception) -> str:
@@ -153,16 +173,26 @@ def load_real_world_dataset(
153
  dirty_path = dataset_dir / "dirty.csv"
154
  clean_path = dataset_dir / "clean.csv"
155
 
 
 
 
156
  if not dirty_path.exists() or not clean_path.exists():
 
157
  try:
158
  _download_to_cache(metadata, dataset_dir)
159
  except Exception as exc: # pragma: no cover - exercised through tests via monkeypatch
160
- raise DatasetDownloadError(
161
- _manual_download_message(metadata, dataset_dir, exc)
162
- ) from exc
163
-
164
- dirty_df = _read_cached_csv(dirty_path)
165
- clean_df = _read_cached_csv(clean_path)
 
 
 
 
 
 
166
 
167
  if len(dirty_df.index) != len(clean_df.index):
168
  raise ValueError(f"Dataset '{name}' dirty/clean row counts do not match.")
 
2
 
3
  from __future__ import annotations
4
 
5
+ import logging
6
+ import os
7
  from dataclasses import dataclass
8
  from pathlib import Path
9
 
 
18
  """Raised when a real-world dataset cannot be downloaded or loaded from cache."""
19
 
20
 
21
+ _LOGGER = logging.getLogger("dataforge.datasets.real_world")
22
+
23
+
24
  class GroundTruthCell(BaseModel):
25
  """Single cell-level dirty-to-clean correction used for benchmark scoring."""
26
 
 
62
 
63
  def _download_bytes(url: str) -> bytes:
64
  """Download raw CSV bytes from an upstream source URL."""
65
+ try:
66
+ timeout = float(os.environ.get("DATAFORGE_DOWNLOAD_TIMEOUT_S", "5"))
67
+ except ValueError:
68
+ timeout = 5.0
69
+ with httpx.Client(timeout=timeout, follow_redirects=True) as client:
70
  response = client.get(url)
71
  response.raise_for_status()
72
  return response.content
 
76
  """Download dirty/clean CSV files into the dataset cache directory."""
77
  dataset_dir.mkdir(parents=True, exist_ok=True)
78
  dirty_url, clean_url = metadata.source_urls
79
+ _LOGGER.info("dataset_download_start name=%s dir=%s", metadata.name, dataset_dir)
80
  (dataset_dir / "dirty.csv").write_bytes(_download_bytes(dirty_url))
81
  (dataset_dir / "clean.csv").write_bytes(_download_bytes(clean_url))
82
+ _LOGGER.info("dataset_download_complete name=%s dir=%s", metadata.name, dataset_dir)
83
+
84
+
85
+ def _load_embedded_dataset(name: str) -> tuple[pd.DataFrame, pd.DataFrame] | None:
86
+ root = Path(__file__).parent / "embedded" / name
87
+ dirty_path = root / "dirty.csv"
88
+ clean_path = root / "clean.csv"
89
+ if not dirty_path.exists() or not clean_path.exists():
90
+ return None
91
+ return _read_cached_csv(dirty_path), _read_cached_csv(clean_path)
92
 
93
 
94
  def _manual_download_message(metadata: DatasetMetadata, dataset_dir: Path, cause: Exception) -> str:
 
173
  dirty_path = dataset_dir / "dirty.csv"
174
  clean_path = dataset_dir / "clean.csv"
175
 
176
+ dirty_df: pd.DataFrame | None = None
177
+ clean_df: pd.DataFrame | None = None
178
+
179
  if not dirty_path.exists() or not clean_path.exists():
180
+ _LOGGER.info("dataset_cache_miss name=%s dir=%s", name, dataset_dir)
181
  try:
182
  _download_to_cache(metadata, dataset_dir)
183
  except Exception as exc: # pragma: no cover - exercised through tests via monkeypatch
184
+ fallback = _load_embedded_dataset(name)
185
+ if fallback is None:
186
+ raise DatasetDownloadError(
187
+ _manual_download_message(metadata, dataset_dir, exc)
188
+ ) from exc
189
+ dirty_df, clean_df = fallback
190
+ else:
191
+ _LOGGER.info("dataset_cache_hit name=%s dir=%s", name, dataset_dir)
192
+
193
+ if dirty_df is None or clean_df is None:
194
+ dirty_df = _read_cached_csv(dirty_path)
195
+ clean_df = _read_cached_csv(clean_path)
196
 
197
  if len(dirty_df.index) != len(clean_df.index):
198
  raise ValueError(f"Dataset '{name}' dirty/clean row counts do not match.")
dataforge/detectors/__init__.py CHANGED
@@ -12,8 +12,6 @@ deduplicated, severity-sorted issue list.
12
 
13
  from __future__ import annotations
14
 
15
- import pandas as pd
16
-
17
  from dataforge.detectors.base import Detector, Issue, Schema, Severity
18
  from dataforge.detectors.decimal_shift import DecimalShiftDetector
19
  from dataforge.detectors.fd_violation import FDViolationDetector
@@ -33,14 +31,14 @@ __all__ = [
33
  _SEVERITY_ORDER = {Severity.UNSAFE: 0, Severity.REVIEW: 1, Severity.SAFE: 2}
34
 
35
 
36
- def run_all_detectors(df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]:
37
  """Run all registered detectors and return a merged, sorted issue list.
38
 
39
  Issues are deduplicated by (row, column, issue_type) and sorted by
40
  severity (UNSAFE first) then confidence (highest first).
41
 
42
  Args:
43
- df: The input DataFrame to analyze.
44
  schema: Optional declared schema with column types and constraints.
45
 
46
  Returns:
 
12
 
13
  from __future__ import annotations
14
 
 
 
15
  from dataforge.detectors.base import Detector, Issue, Schema, Severity
16
  from dataforge.detectors.decimal_shift import DecimalShiftDetector
17
  from dataforge.detectors.fd_violation import FDViolationDetector
 
31
  _SEVERITY_ORDER = {Severity.UNSAFE: 0, Severity.REVIEW: 1, Severity.SAFE: 2}
32
 
33
 
34
+ def run_all_detectors(df: object, schema: Schema | None = None) -> list[Issue]:
35
  """Run all registered detectors and return a merged, sorted issue list.
36
 
37
  Issues are deduplicated by (row, column, issue_type) and sorted by
38
  severity (UNSAFE first) then confidence (highest first).
39
 
40
  Args:
41
+ df: The input table to analyze.
42
  schema: Optional declared schema with column types and constraints.
43
 
44
  Returns:
dataforge/detectors/base.py CHANGED
@@ -5,9 +5,9 @@ from __future__ import annotations
5
  import enum
6
  from typing import Literal, Protocol
7
 
8
- import pandas as pd
9
  from pydantic import BaseModel, Field
10
 
 
11
  from dataforge.verifier.schema import (
12
  AggregateDependency,
13
  DomainBound,
@@ -114,23 +114,23 @@ class Issue(BaseModel):
114
  class Detector(Protocol):
115
  """Structural protocol that every detector must implement.
116
 
117
- A detector is a pure function over tabular data: it receives a DataFrame
118
  and an optional Schema, and returns a list of Issue objects. No LLM calls,
119
  no disk I/O, no side effects.
120
 
121
  Example:
122
  >>> class MyDetector:
123
  ... def detect(
124
- ... self, df: pd.DataFrame, schema: Schema | None = None
125
  ... ) -> list[Issue]:
126
  ... return []
127
  """
128
 
129
- def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]:
130
  """Detect data-quality issues in the given DataFrame.
131
 
132
  Args:
133
- df: The input DataFrame to analyze.
134
  schema: Optional declared schema with column types and constraints.
135
 
136
  Returns:
 
5
  import enum
6
  from typing import Literal, Protocol
7
 
 
8
  from pydantic import BaseModel, Field
9
 
10
+ from dataforge.table import TableLike
11
  from dataforge.verifier.schema import (
12
  AggregateDependency,
13
  DomainBound,
 
114
  class Detector(Protocol):
115
  """Structural protocol that every detector must implement.
116
 
117
+ A detector is a pure function over tabular data: it receives a table
118
  and an optional Schema, and returns a list of Issue objects. No LLM calls,
119
  no disk I/O, no side effects.
120
 
121
  Example:
122
  >>> class MyDetector:
123
  ... def detect(
124
+ ... self, df: TableLike, schema: Schema | None = None
125
  ... ) -> list[Issue]:
126
  ... return []
127
  """
128
 
129
+ def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
130
  """Detect data-quality issues in the given DataFrame.
131
 
132
  Args:
133
+ df: The input table to analyze.
134
  schema: Optional declared schema with column types and constraints.
135
 
136
  Returns:
dataforge/detectors/decimal_shift.py CHANGED
@@ -10,15 +10,10 @@ The detector is **pure**: no LLM calls, no I/O, no side effects.
10
  from __future__ import annotations
11
 
12
  import math
13
- from typing import TYPE_CHECKING
14
-
15
- import numpy as np
16
- import pandas as pd
17
 
18
  from dataforge.detectors.base import Issue, Schema, Severity
19
-
20
- if TYPE_CHECKING:
21
- pass
22
 
23
  # Minimum non-null numeric values required for meaningful statistics.
24
  _MIN_COLUMN_SIZE = 5
@@ -70,7 +65,7 @@ class DecimalShiftDetector:
70
  3
71
  """
72
 
73
- def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]:
74
  """Detect decimal-shift issues in the DataFrame.
75
 
76
  Args:
@@ -83,13 +78,13 @@ class DecimalShiftDetector:
83
  """
84
  issues: list[Issue] = []
85
 
86
- for col_name in df.columns:
87
  col_issues = self._check_column(df, str(col_name))
88
  issues.extend(col_issues)
89
 
90
  return issues
91
 
92
- def _check_column(self, df: pd.DataFrame, col_name: str) -> list[Issue]:
93
  """Check a single column for decimal-shift outliers.
94
 
95
  Args:
@@ -101,7 +96,7 @@ class DecimalShiftDetector:
101
  """
102
  # Parse all values to float, keeping track of original indices.
103
  parsed: list[tuple[int, float, str]] = []
104
- for row_idx, val in enumerate(df[col_name].tolist()):
105
  fval = _try_float(val)
106
  if fval is not None:
107
  parsed.append((row_idx, fval, str(val)))
@@ -109,11 +104,10 @@ class DecimalShiftDetector:
109
  if len(parsed) < _MIN_COLUMN_SIZE:
110
  return []
111
 
112
- values = np.array([v for _, v, _ in parsed])
113
- median = float(np.median(values))
114
 
115
  # If median is zero or very close, we cannot compute meaningful ratios.
116
- if abs(median) < 1e-10:
117
  return []
118
 
119
  issues: list[Issue] = []
@@ -121,7 +115,7 @@ class DecimalShiftDetector:
121
  if abs(fval) < 1e-10:
122
  continue
123
 
124
- ratio = fval / median
125
  if abs(ratio) < 1e-10:
126
  continue
127
 
@@ -147,13 +141,13 @@ class DecimalShiftDetector:
147
  reason = (
148
  f"Value {fval:g} in column '{col_name}' appears to be "
149
  f"~{int(correction_factor)}x the typical value "
150
- f"(median ~{median:g})"
151
  )
152
  else:
153
  reason = (
154
  f"Value {fval:g} in column '{col_name}' appears to be "
155
  f"~{1.0 / correction_factor:g}x too small compared to "
156
- f"the typical value (median ~{median:g})"
157
  )
158
 
159
  issues.append(
 
10
  from __future__ import annotations
11
 
12
  import math
13
+ from statistics import median
 
 
 
14
 
15
  from dataforge.detectors.base import Issue, Schema, Severity
16
+ from dataforge.table import TableLike, column_names, column_values
 
 
17
 
18
  # Minimum non-null numeric values required for meaningful statistics.
19
  _MIN_COLUMN_SIZE = 5
 
65
  3
66
  """
67
 
68
+ def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
69
  """Detect decimal-shift issues in the DataFrame.
70
 
71
  Args:
 
78
  """
79
  issues: list[Issue] = []
80
 
81
+ for col_name in column_names(df):
82
  col_issues = self._check_column(df, str(col_name))
83
  issues.extend(col_issues)
84
 
85
  return issues
86
 
87
+ def _check_column(self, df: TableLike, col_name: str) -> list[Issue]:
88
  """Check a single column for decimal-shift outliers.
89
 
90
  Args:
 
96
  """
97
  # Parse all values to float, keeping track of original indices.
98
  parsed: list[tuple[int, float, str]] = []
99
+ for row_idx, val in enumerate(column_values(df, col_name)):
100
  fval = _try_float(val)
101
  if fval is not None:
102
  parsed.append((row_idx, fval, str(val)))
 
104
  if len(parsed) < _MIN_COLUMN_SIZE:
105
  return []
106
 
107
+ center = float(median([v for _, v, _ in parsed]))
 
108
 
109
  # If median is zero or very close, we cannot compute meaningful ratios.
110
+ if abs(center) < 1e-10:
111
  return []
112
 
113
  issues: list[Issue] = []
 
115
  if abs(fval) < 1e-10:
116
  continue
117
 
118
+ ratio = fval / center
119
  if abs(ratio) < 1e-10:
120
  continue
121
 
 
141
  reason = (
142
  f"Value {fval:g} in column '{col_name}' appears to be "
143
  f"~{int(correction_factor)}x the typical value "
144
+ f"(median ~{center:g})"
145
  )
146
  else:
147
  reason = (
148
  f"Value {fval:g} in column '{col_name}' appears to be "
149
  f"~{1.0 / correction_factor:g}x too small compared to "
150
+ f"the typical value (median ~{center:g})"
151
  )
152
 
153
  issues.append(
dataforge/detectors/fd_violation.py CHANGED
@@ -12,14 +12,8 @@ The detector is **pure**: no LLM calls, no I/O, no side effects.
12
 
13
  from __future__ import annotations
14
 
15
- from typing import TYPE_CHECKING
16
-
17
- import pandas as pd
18
-
19
  from dataforge.detectors.base import Issue, Schema, Severity
20
-
21
- if TYPE_CHECKING:
22
- pass
23
 
24
 
25
  class FDViolationDetector:
@@ -49,7 +43,7 @@ class FDViolationDetector:
49
  2
50
  """
51
 
52
- def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]:
53
  """Detect FD-violation issues in the DataFrame.
54
 
55
  Args:
@@ -73,7 +67,7 @@ class FDViolationDetector:
73
 
74
  def _check_fd(
75
  self,
76
- df: pd.DataFrame,
77
  determinant: tuple[str, ...],
78
  dependent: str,
79
  ) -> list[Issue]:
@@ -91,34 +85,37 @@ class FDViolationDetector:
91
 
92
  # Verify all columns exist in the DataFrame.
93
  all_cols = [*determinant_columns, dependent]
 
94
  for col in all_cols:
95
- if col not in df.columns:
96
  return []
97
 
98
- # Drop rows with null values in determinant columns.
99
- subset = df[all_cols].copy()
100
- mask = subset[determinant_columns].notna().all(axis=1)
101
- subset = subset[mask]
 
 
102
 
103
- if subset.empty:
104
  return []
105
 
106
- # Group by determinant and find groups with multiple distinct
107
- # dependent values.
108
  issues: list[Issue] = []
109
-
110
- grouped = subset.groupby(determinant_columns, sort=False)
111
- for group_key, group_df in grouped:
112
- unique_deps = group_df[dependent].dropna().unique()
 
 
 
113
  if len(unique_deps) <= 1:
114
  continue
115
 
116
- # All rows in this group are part of the violation.
117
  det_desc = self._format_determinant(determinant, group_key)
118
  unique_str = ", ".join(repr(str(v)) for v in unique_deps)
119
 
120
- for idx in group_df.index:
121
- actual_val = str(group_df.at[idx, dependent])
122
  reason = (
123
  f"Functional dependency {determinant} -> {dependent} "
124
  f"violated: {det_desc} maps to multiple values: "
 
12
 
13
  from __future__ import annotations
14
 
 
 
 
 
15
  from dataforge.detectors.base import Issue, Schema, Severity
16
+ from dataforge.table import TableLike, cell_value, column_names, row_count
 
 
17
 
18
 
19
  class FDViolationDetector:
 
43
  2
44
  """
45
 
46
+ def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
47
  """Detect FD-violation issues in the DataFrame.
48
 
49
  Args:
 
67
 
68
  def _check_fd(
69
  self,
70
+ df: TableLike,
71
  determinant: tuple[str, ...],
72
  dependent: str,
73
  ) -> list[Issue]:
 
85
 
86
  # Verify all columns exist in the DataFrame.
87
  all_cols = [*determinant_columns, dependent]
88
+ available_columns = set(column_names(df))
89
  for col in all_cols:
90
+ if col not in available_columns:
91
  return []
92
 
93
+ groups: dict[tuple[str, ...], list[int]] = {}
94
+ for row in range(row_count(df)):
95
+ group_key = tuple(cell_value(df, row, column) for column in determinant_columns)
96
+ if any(value == "" for value in group_key):
97
+ continue
98
+ groups.setdefault(group_key, []).append(row)
99
 
100
+ if not groups:
101
  return []
102
 
 
 
103
  issues: list[Issue] = []
104
+ for group_key, row_indices in groups.items():
105
+ unique_deps: list[str] = []
106
+ for row in row_indices:
107
+ value = cell_value(df, row, dependent)
108
+ if value == "" or value in unique_deps:
109
+ continue
110
+ unique_deps.append(value)
111
  if len(unique_deps) <= 1:
112
  continue
113
 
 
114
  det_desc = self._format_determinant(determinant, group_key)
115
  unique_str = ", ".join(repr(str(v)) for v in unique_deps)
116
 
117
+ for idx in row_indices:
118
+ actual_val = cell_value(df, idx, dependent)
119
  reason = (
120
  f"Functional dependency {determinant} -> {dependent} "
121
  f"violated: {det_desc} maps to multiple values: "
dataforge/detectors/type_mismatch.py CHANGED
@@ -10,14 +10,9 @@ The detector is **pure**: no LLM calls, no I/O, no side effects.
10
  from __future__ import annotations
11
 
12
  import re
13
- from typing import TYPE_CHECKING
14
-
15
- import pandas as pd
16
 
17
  from dataforge.detectors.base import Issue, Schema, Severity
18
-
19
- if TYPE_CHECKING:
20
- pass
21
 
22
  # Compiled regexes for type inference.
23
  _NUMERIC_RE = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$")
@@ -69,7 +64,7 @@ class TypeMismatchDetector:
69
  'N/A'
70
  """
71
 
72
- def detect(self, df: pd.DataFrame, schema: Schema | None = None) -> list[Issue]:
73
  """Detect type-mismatch issues in the DataFrame.
74
 
75
  Args:
@@ -84,13 +79,13 @@ class TypeMismatchDetector:
84
  """
85
  issues: list[Issue] = []
86
 
87
- for col_name in df.columns:
88
  col_issues = self._check_column(df, str(col_name))
89
  issues.extend(col_issues)
90
 
91
  return issues
92
 
93
- def _check_column(self, df: pd.DataFrame, col_name: str) -> list[Issue]:
94
  """Check a single column for type mismatches.
95
 
96
  Args:
@@ -100,12 +95,10 @@ class TypeMismatchDetector:
100
  Returns:
101
  Issues found in this column.
102
  """
103
- series = df[col_name]
104
-
105
  # Collect (index, value, type) for non-null entries.
106
  classified: list[tuple[int, str, str]] = []
107
- for row_idx, val in enumerate(series.tolist()):
108
- if pd.isna(val):
109
  continue
110
  str_val = str(val).strip()
111
  if not str_val:
 
10
  from __future__ import annotations
11
 
12
  import re
 
 
 
13
 
14
  from dataforge.detectors.base import Issue, Schema, Severity
15
+ from dataforge.table import TableLike, column_names, column_values
 
 
16
 
17
  # Compiled regexes for type inference.
18
  _NUMERIC_RE = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$")
 
64
  'N/A'
65
  """
66
 
67
+ def detect(self, df: TableLike, schema: Schema | None = None) -> list[Issue]:
68
  """Detect type-mismatch issues in the DataFrame.
69
 
70
  Args:
 
79
  """
80
  issues: list[Issue] = []
81
 
82
+ for col_name in column_names(df):
83
  col_issues = self._check_column(df, str(col_name))
84
  issues.extend(col_issues)
85
 
86
  return issues
87
 
88
+ def _check_column(self, df: TableLike, col_name: str) -> list[Issue]:
89
  """Check a single column for type mismatches.
90
 
91
  Args:
 
95
  Returns:
96
  Issues found in this column.
97
  """
 
 
98
  # Collect (index, value, type) for non-null entries.
99
  classified: list[tuple[int, str, str]] = []
100
+ for row_idx, val in enumerate(column_values(df, col_name)):
101
+ if val is None:
102
  continue
103
  str_val = str(val).strip()
104
  if not str_val:
dataforge/engine/__init__.py CHANGED
@@ -1 +1,33 @@
1
- """Engine package scaffolding for DataForge."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public backend engine APIs for DataForge."""
2
+
3
+ from dataforge.engine.repair import (
4
+ CandidateFix,
5
+ RepairFailure,
6
+ RepairMode,
7
+ RepairPipelineRequest,
8
+ RepairPipelineResult,
9
+ RepairReceipt,
10
+ VerifiedFix,
11
+ apply_fixes_to_csv,
12
+ apply_transaction,
13
+ create_repair_transaction,
14
+ propose_repairs,
15
+ run_repair_pipeline,
16
+ source_path_lock,
17
+ )
18
+
19
+ __all__ = [
20
+ "CandidateFix",
21
+ "RepairFailure",
22
+ "RepairMode",
23
+ "RepairPipelineRequest",
24
+ "RepairPipelineResult",
25
+ "RepairReceipt",
26
+ "VerifiedFix",
27
+ "apply_fixes_to_csv",
28
+ "apply_transaction",
29
+ "create_repair_transaction",
30
+ "propose_repairs",
31
+ "run_repair_pipeline",
32
+ "source_path_lock",
33
+ ]
dataforge/engine/repair.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public repair engine for DataForge backend surfaces.
2
+
3
+ The engine is the stable boundary shared by CLI, Playground, MCP, and any
4
+ OpenEnv adapter that needs repair semantics. It keeps the core invariant in one
5
+ place: detect -> propose -> safety -> SMT verification -> journal/snapshot ->
6
+ atomic mutation -> byte-identical revert.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import hashlib
12
+ import os
13
+ import secrets
14
+ import time
15
+ from collections.abc import Callable, Iterator
16
+ from contextlib import contextmanager, suppress
17
+ from datetime import UTC, datetime
18
+ from pathlib import Path
19
+ from typing import Literal
20
+
21
+ from pydantic import BaseModel, ConfigDict, Field
22
+
23
+ from dataforge.detectors import run_all_detectors
24
+ from dataforge.detectors.base import Issue, Schema
25
+ from dataforge.observability import repair_stage_span
26
+ from dataforge.repair_contract import CONTRACT_VERSION
27
+ from dataforge.repairers import build_repairers
28
+ from dataforge.repairers.base import ProposedFix, RepairAttempt, RetryContext
29
+ from dataforge.safety import SafetyContext, SafetyFilter, SafetyResult, SafetyVerdict
30
+ from dataforge.table import (
31
+ Table,
32
+ TableLike,
33
+ cell_value,
34
+ column_names,
35
+ copy_table,
36
+ row_count,
37
+ set_cell_value,
38
+ table_to_csv_bytes,
39
+ )
40
+ from dataforge.table import (
41
+ read_csv as read_table_csv,
42
+ )
43
+ from dataforge.transactions.log import (
44
+ append_applied_event,
45
+ append_created_transaction,
46
+ cache_dir_for,
47
+ sha256_bytes,
48
+ sha256_file,
49
+ snapshot_path_for,
50
+ )
51
+ from dataforge.transactions.txn import CellFix, RepairTransaction, generate_txn_id
52
+ from dataforge.verifier import SMTVerifier, VerificationVerdict
53
+
54
+ RepairMode = Literal["dry_run", "apply"]
55
+ EscalationResolver = Callable[
56
+ [ProposedFix, Schema | None, SafetyContext, SafetyFilter, SafetyResult],
57
+ tuple[SafetyContext, SafetyResult],
58
+ ]
59
+
60
+
61
+ class RepairEngineError(RuntimeError):
62
+ """Base exception for public repair engine failures."""
63
+
64
+
65
+ class TransactionApplyError(RepairEngineError):
66
+ """Raised when an apply transaction cannot be completed safely."""
67
+
68
+
69
+ class CandidateFix(BaseModel):
70
+ """Stable public representation of a proposed cell repair."""
71
+
72
+ row: int = Field(ge=0)
73
+ column: str = Field(min_length=1)
74
+ old_value: str
75
+ new_value: str
76
+ detector_id: str = Field(min_length=1)
77
+ operation: Literal["update", "delete_row"] = "update"
78
+ reason: str = Field(min_length=1)
79
+ confidence: float = Field(ge=0.0, le=1.0)
80
+ provenance: str = Field(min_length=1)
81
+
82
+ model_config = ConfigDict(strict=True, extra="forbid", frozen=True)
83
+
84
+ @classmethod
85
+ def from_proposed(cls, proposed_fix: ProposedFix) -> CandidateFix:
86
+ """Create a public candidate from an internal repair proposal."""
87
+ fix = proposed_fix.fix
88
+ return cls(
89
+ row=fix.row,
90
+ column=fix.column,
91
+ old_value=fix.old_value,
92
+ new_value=fix.new_value,
93
+ detector_id=fix.detector_id,
94
+ operation=fix.operation,
95
+ reason=proposed_fix.reason,
96
+ confidence=proposed_fix.confidence,
97
+ provenance=proposed_fix.provenance,
98
+ )
99
+
100
+
101
+ class VerifiedFix(CandidateFix):
102
+ """A candidate that passed safety and SMT verification."""
103
+
104
+ verifier_reason: str = Field(min_length=1)
105
+
106
+
107
+ class RepairFailure(BaseModel):
108
+ """Machine-readable account of an issue that could not be repaired."""
109
+
110
+ row: int = Field(ge=0)
111
+ column: str = Field(min_length=1)
112
+ issue_type: str = Field(min_length=1)
113
+ status: str = Field(min_length=1)
114
+ reason: str = Field(min_length=1)
115
+ attempt_count: int = Field(ge=1)
116
+ unsat_core: tuple[str, ...] = Field(default_factory=tuple)
117
+
118
+ model_config = ConfigDict(strict=True, extra="forbid", frozen=True)
119
+
120
+ @classmethod
121
+ def from_attempts(cls, attempts: list[RepairAttempt]) -> RepairFailure:
122
+ """Build a public failure record from one issue's attempt trace."""
123
+ final = attempts[-1]
124
+ issue = final.issue
125
+ return cls(
126
+ row=issue.row,
127
+ column=issue.column,
128
+ issue_type=issue.issue_type,
129
+ status=final.status,
130
+ reason=final.reason,
131
+ attempt_count=len(attempts),
132
+ unsat_core=tuple(final.unsat_core),
133
+ )
134
+
135
+
136
+ class RepairReceipt(BaseModel):
137
+ """Stable receipt for a dry-run or applied repair pipeline run."""
138
+
139
+ contract_version: str = CONTRACT_VERSION
140
+ mode: RepairMode
141
+ applied: bool
142
+ reversible: bool
143
+ source_path: str
144
+ source_sha256: str = Field(pattern=r"^[0-9a-f]{64}$")
145
+ post_sha256: str | None = Field(default=None, pattern=r"^[0-9a-f]{64}$")
146
+ txn_id: str | None = None
147
+ allowed_columns: list[str] = Field(default_factory=list)
148
+ valid_rows: list[int] = Field(default_factory=list)
149
+ issues_count: int = Field(ge=0)
150
+ fixes_count: int = Field(ge=0)
151
+ reason: str = Field(min_length=1)
152
+
153
+ model_config = ConfigDict(strict=True, extra="forbid", frozen=True)
154
+
155
+
156
+ class RepairPipelineRequest(BaseModel):
157
+ """Input contract for running the public repair pipeline."""
158
+
159
+ source_path: Path
160
+ mode: RepairMode = "dry_run"
161
+ repair_schema: Schema | None = Field(default=None, alias="schema")
162
+ allow_llm: bool = False
163
+ model: str = "gemini-2.0-flash"
164
+ allow_pii: bool = False
165
+ confirm_pii: bool = False
166
+ confirm_escalations: bool = False
167
+ interactive: bool = False
168
+ create_dry_run_transaction: bool = False
169
+
170
+ model_config = ConfigDict(
171
+ strict=True,
172
+ arbitrary_types_allowed=True,
173
+ extra="forbid",
174
+ frozen=True,
175
+ populate_by_name=True,
176
+ )
177
+
178
+
179
+ class RepairPipelineResult(BaseModel):
180
+ """Output contract for a public repair pipeline run."""
181
+
182
+ receipt: RepairReceipt
183
+ issues: list[Issue]
184
+ fixes: list[VerifiedFix]
185
+ failures: list[RepairFailure] = Field(default_factory=list)
186
+ transaction: RepairTransaction | None = None
187
+
188
+ model_config = ConfigDict(
189
+ strict=True, arbitrary_types_allowed=True, extra="forbid", frozen=True
190
+ )
191
+
192
+
193
+ def _atomic_write_bytes(path: Path, payload: bytes) -> None:
194
+ """Write bytes to ``path`` through an atomic same-directory replacement."""
195
+ resolved = path.resolve()
196
+ resolved.parent.mkdir(parents=True, exist_ok=True)
197
+ temp_path = resolved.with_name(f".{resolved.name}.{secrets.token_hex(8)}.tmp")
198
+ try:
199
+ with temp_path.open("xb") as handle:
200
+ handle.write(payload)
201
+ handle.flush()
202
+ os.fsync(handle.fileno())
203
+ os.replace(temp_path, resolved)
204
+ finally:
205
+ if temp_path.exists():
206
+ temp_path.unlink()
207
+
208
+
209
+ def read_csv(path: Path) -> Table:
210
+ """Read a CSV using conservative string-preserving defaults."""
211
+ return read_table_csv(path)
212
+
213
+
214
+ def _csv_bytes_after_fixes(path: Path, fixes: list[CellFix]) -> bytes:
215
+ """Validate fixes against a CSV and return the mutated CSV bytes."""
216
+ df = read_csv(path)
217
+ for fix in fixes:
218
+ if fix.operation != "update":
219
+ raise ValueError(f"Unsupported repair operation '{fix.operation}' for row {fix.row}.")
220
+ if fix.column not in column_names(df):
221
+ raise ValueError(f"Column '{fix.column}' not found in '{path}'.")
222
+ if fix.row < 0 or fix.row >= row_count(df):
223
+ raise ValueError(f"Row {fix.row} is out of bounds for '{path}'.")
224
+
225
+ current_value = cell_value(df, fix.row, fix.column)
226
+ if current_value != fix.old_value:
227
+ raise ValueError(
228
+ f"Refusing to apply stale fix for row {fix.row}, column '{fix.column}': "
229
+ f"expected '{fix.old_value}', found '{current_value}'."
230
+ )
231
+ set_cell_value(df, fix.row, fix.column, fix.new_value)
232
+
233
+ return table_to_csv_bytes(df)
234
+
235
+
236
+ def apply_fixes_to_csv(path: Path, fixes: list[CellFix]) -> str:
237
+ """Atomically apply ordered cell fixes to a CSV and return post-state SHA-256."""
238
+ payload = _csv_bytes_after_fixes(path, fixes)
239
+ _atomic_write_bytes(path, payload)
240
+ return hashlib.sha256(payload).hexdigest()
241
+
242
+
243
+ def _lock_path_for(source_path: Path) -> Path:
244
+ """Return the filesystem lock path for a source file."""
245
+ digest = hashlib.sha256(str(source_path.resolve()).encode("utf-8")).hexdigest()[:24]
246
+ return source_path.resolve().parent / ".dataforge" / "locks" / f"{digest}.lock"
247
+
248
+
249
+ @contextmanager
250
+ def source_path_lock(
251
+ source_path: Path,
252
+ *,
253
+ timeout_seconds: float = 5.0,
254
+ stale_after_seconds: float = 300.0,
255
+ ) -> Iterator[None]:
256
+ """Acquire an exclusive lock for a source path using an atomic lock file."""
257
+ lock_path = _lock_path_for(source_path)
258
+ lock_path.parent.mkdir(parents=True, exist_ok=True)
259
+ deadline = time.monotonic() + timeout_seconds
260
+ while True:
261
+ try:
262
+ fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
263
+ try:
264
+ payload = f"{os.getpid()} {datetime.now(UTC).isoformat()}\n".encode()
265
+ os.write(fd, payload)
266
+ finally:
267
+ os.close(fd)
268
+ break
269
+ except FileExistsError as exc:
270
+ try:
271
+ age = time.time() - lock_path.stat().st_mtime
272
+ except OSError:
273
+ age = 0.0
274
+ if age > stale_after_seconds:
275
+ try:
276
+ lock_path.unlink()
277
+ continue
278
+ except OSError:
279
+ pass
280
+ if time.monotonic() >= deadline:
281
+ raise TransactionApplyError(
282
+ f"Timed out waiting for DataForge source lock: {source_path.resolve()}"
283
+ ) from exc
284
+ time.sleep(0.05)
285
+
286
+ try:
287
+ yield
288
+ finally:
289
+ with suppress(FileNotFoundError):
290
+ lock_path.unlink()
291
+
292
+
293
+ def _write_snapshot_once(snapshot_path: Path, source_bytes: bytes) -> None:
294
+ """Write an immutable snapshot and fail if the transaction id already exists."""
295
+ snapshot_path.parent.mkdir(parents=True, exist_ok=True)
296
+ try:
297
+ with snapshot_path.open("xb") as handle:
298
+ handle.write(source_bytes)
299
+ handle.flush()
300
+ os.fsync(handle.fileno())
301
+ except FileExistsError as exc:
302
+ raise TransactionApplyError(
303
+ f"Transaction snapshot already exists: {snapshot_path}"
304
+ ) from exc
305
+
306
+
307
+ def create_repair_transaction(
308
+ path: Path,
309
+ fixes: list[ProposedFix],
310
+ source_bytes: bytes,
311
+ *,
312
+ txn_id: str | None = None,
313
+ ) -> tuple[RepairTransaction, Path]:
314
+ """Create an unapplied transaction journal and immutable source snapshot."""
315
+ resolved_path = path.resolve()
316
+ transaction_id = txn_id or generate_txn_id()
317
+ snapshot_path = snapshot_path_for(resolved_path, transaction_id)
318
+ _write_snapshot_once(snapshot_path, source_bytes)
319
+
320
+ transaction = RepairTransaction(
321
+ txn_id=transaction_id,
322
+ created_at=datetime.now(UTC),
323
+ source_path=str(resolved_path),
324
+ source_sha256=sha256_bytes(source_bytes),
325
+ source_snapshot_path=str(snapshot_path.resolve()),
326
+ fixes=[proposal.fix for proposal in fixes],
327
+ applied=False,
328
+ )
329
+ try:
330
+ log_path = append_created_transaction(transaction)
331
+ except Exception:
332
+ snapshot_path.unlink(missing_ok=True)
333
+ raise
334
+ return transaction, log_path
335
+
336
+
337
+ def apply_transaction(
338
+ path: Path,
339
+ fixes: list[ProposedFix],
340
+ source_bytes: bytes,
341
+ *,
342
+ txn_id: str | None = None,
343
+ ) -> str:
344
+ """Journal, snapshot, atomically apply fixes, and restore bytes on failure."""
345
+ resolved_path = path.resolve()
346
+ with source_path_lock(resolved_path):
347
+ current_bytes = resolved_path.read_bytes()
348
+ if current_bytes != source_bytes:
349
+ raise TransactionApplyError(
350
+ "Refusing to apply repairs because the source file changed after detection."
351
+ )
352
+
353
+ with repair_stage_span("dataforge.repair.transaction.create", fixes_count=len(fixes)):
354
+ transaction, log_path = create_repair_transaction(
355
+ resolved_path,
356
+ fixes,
357
+ source_bytes,
358
+ txn_id=txn_id,
359
+ )
360
+ try:
361
+ with repair_stage_span("dataforge.repair.transaction.apply", fixes_count=len(fixes)):
362
+ post_sha256 = apply_fixes_to_csv(
363
+ resolved_path,
364
+ [proposal.fix for proposal in fixes],
365
+ )
366
+ append_applied_event(log_path, transaction.txn_id, post_sha256=post_sha256)
367
+ except Exception as exc:
368
+ _atomic_write_bytes(resolved_path, source_bytes)
369
+ if sha256_file(resolved_path) != transaction.source_sha256:
370
+ raise TransactionApplyError(
371
+ "Apply failed and the source file could not be restored to original bytes."
372
+ ) from exc
373
+ raise
374
+
375
+ return transaction.txn_id
376
+
377
+
378
+ def _build_retry_context(issue: Issue, attempts: list[RepairAttempt]) -> RetryContext:
379
+ """Build retry hints from previous failed attempts."""
380
+ rejected_values = frozenset(
381
+ attempt.fix.fix.new_value
382
+ for attempt in attempts
383
+ if attempt.fix is not None and attempt.status in {"denied", "rejected", "unknown"}
384
+ )
385
+ hints: list[str] = []
386
+ for attempt in attempts:
387
+ hints.append(attempt.reason)
388
+ hints.extend(attempt.unsat_core)
389
+ return RetryContext(
390
+ issue=issue,
391
+ previous_attempts=tuple(attempts),
392
+ rejected_values=rejected_values,
393
+ hints=tuple(hints),
394
+ )
395
+
396
+
397
+ def propose_repairs(
398
+ issues: list[Issue],
399
+ path: Path,
400
+ working_df: TableLike,
401
+ schema: Schema | None,
402
+ *,
403
+ allow_llm: bool,
404
+ model: str,
405
+ allow_pii: bool,
406
+ confirm_pii: bool,
407
+ confirm_escalations: bool,
408
+ interactive: bool,
409
+ escalation_resolver: EscalationResolver | None = None,
410
+ ) -> tuple[list[ProposedFix], list[list[RepairAttempt]]]:
411
+ """Run repairers and gates issue-by-issue against a working dataframe."""
412
+ with repair_stage_span("dataforge.repair.repairers.build", allow_llm=allow_llm):
413
+ repairers = build_repairers(
414
+ cache_dir=cache_dir_for(path),
415
+ allow_llm=allow_llm,
416
+ model=model,
417
+ )
418
+ safety_filter = SafetyFilter()
419
+ verifier = SMTVerifier()
420
+ safety_context = SafetyContext(
421
+ allow_pii=allow_pii,
422
+ confirm_pii=confirm_pii,
423
+ confirm_escalations=confirm_escalations,
424
+ )
425
+
426
+ accepted_fixes: list[ProposedFix] = []
427
+ attempt_groups: list[list[RepairAttempt]] = []
428
+
429
+ for issue in issues:
430
+ attempts: list[RepairAttempt] = []
431
+ repairer = repairers.get(issue.issue_type)
432
+ if repairer is None:
433
+ attempts.append(
434
+ RepairAttempt(
435
+ issue=issue,
436
+ attempt_number=1,
437
+ status="attempted_not_fixed",
438
+ reason="No repairer is registered for this issue type.",
439
+ )
440
+ )
441
+ attempt_groups.append(attempts)
442
+ continue
443
+
444
+ accepted = False
445
+ retry_context = RetryContext(issue=issue)
446
+ for attempt_number in range(1, 4):
447
+ candidate = repairer.propose(issue, working_df, schema, retry_context=retry_context)
448
+ if candidate is None:
449
+ attempts.append(
450
+ RepairAttempt(
451
+ issue=issue,
452
+ attempt_number=attempt_number,
453
+ status="attempted_not_fixed",
454
+ reason="No repair proposal was available for this issue.",
455
+ )
456
+ )
457
+ break
458
+
459
+ preferred = safety_filter.choose_preferred([candidate], schema, safety_context)
460
+ safety_result = safety_filter.evaluate(preferred, schema, safety_context)
461
+ if (
462
+ safety_result.verdict == SafetyVerdict.ESCALATE
463
+ and interactive
464
+ and escalation_resolver is not None
465
+ ):
466
+ safety_context, safety_result = escalation_resolver(
467
+ preferred,
468
+ schema,
469
+ safety_context,
470
+ safety_filter,
471
+ safety_result,
472
+ )
473
+
474
+ if safety_result.verdict == SafetyVerdict.DENY:
475
+ attempts.append(
476
+ RepairAttempt(
477
+ issue=issue,
478
+ attempt_number=attempt_number,
479
+ fix=preferred,
480
+ status="denied",
481
+ reason=safety_result.reason,
482
+ )
483
+ )
484
+ retry_context = _build_retry_context(issue, attempts)
485
+ continue
486
+
487
+ if safety_result.verdict == SafetyVerdict.ESCALATE:
488
+ attempts.append(
489
+ RepairAttempt(
490
+ issue=issue,
491
+ attempt_number=attempt_number,
492
+ fix=preferred,
493
+ status="escalated",
494
+ reason=safety_result.reason,
495
+ )
496
+ )
497
+ break
498
+
499
+ with repair_stage_span(
500
+ "dataforge.repair.verifier.verify",
501
+ issue_type=issue.issue_type,
502
+ row=issue.row,
503
+ ):
504
+ verifier_result = verifier.verify(working_df, [preferred], schema)
505
+ if verifier_result.verdict == VerificationVerdict.ACCEPT:
506
+ accepted_fixes.append(preferred)
507
+ set_cell_value(
508
+ working_df,
509
+ preferred.fix.row,
510
+ preferred.fix.column,
511
+ preferred.fix.new_value,
512
+ )
513
+ attempts.append(
514
+ RepairAttempt(
515
+ issue=issue,
516
+ attempt_number=attempt_number,
517
+ fix=preferred,
518
+ status="accepted",
519
+ reason=verifier_result.reason,
520
+ )
521
+ )
522
+ accepted = True
523
+ break
524
+
525
+ attempts.append(
526
+ RepairAttempt(
527
+ issue=issue,
528
+ attempt_number=attempt_number,
529
+ fix=preferred,
530
+ status=(
531
+ "rejected"
532
+ if verifier_result.verdict == VerificationVerdict.REJECT
533
+ else "unknown"
534
+ ),
535
+ reason=verifier_result.reason,
536
+ unsat_core=verifier_result.unsat_core,
537
+ )
538
+ )
539
+ retry_context = _build_retry_context(issue, attempts)
540
+
541
+ if (
542
+ not accepted
543
+ and attempts
544
+ and attempts[-1].status not in {"attempted_not_fixed", "escalated"}
545
+ ):
546
+ last_reason = attempts[-1].reason
547
+ attempts[-1] = attempts[-1].model_copy(
548
+ update={
549
+ "status": "attempted_not_fixed",
550
+ "reason": (
551
+ f"Issue was attempted but not fixed after {len(attempts)} attempt(s). "
552
+ f"Last failure: {last_reason}"
553
+ ),
554
+ }
555
+ )
556
+ attempt_groups.append(attempts)
557
+
558
+ return accepted_fixes, attempt_groups
559
+
560
+
561
+ def _verified_fixes(
562
+ fixes: list[ProposedFix],
563
+ attempt_groups: list[list[RepairAttempt]],
564
+ ) -> list[VerifiedFix]:
565
+ """Build public verified fix payloads using accepted attempt reasons."""
566
+ accepted_reasons: dict[tuple[int, str, str], str] = {}
567
+ for attempts in attempt_groups:
568
+ for attempt in attempts:
569
+ if attempt.status == "accepted" and attempt.fix is not None:
570
+ fix = attempt.fix.fix
571
+ accepted_reasons[(fix.row, fix.column, fix.new_value)] = attempt.reason
572
+
573
+ return [
574
+ VerifiedFix(
575
+ **CandidateFix.from_proposed(fix).model_dump(),
576
+ verifier_reason=accepted_reasons.get(
577
+ (fix.fix.row, fix.fix.column, fix.fix.new_value),
578
+ "Accepted by verifier.",
579
+ ),
580
+ )
581
+ for fix in fixes
582
+ ]
583
+
584
+
585
+ def _failed_attempts(attempt_groups: list[list[RepairAttempt]]) -> list[RepairFailure]:
586
+ """Return failures for issue groups whose final status was not accepted."""
587
+ return [
588
+ RepairFailure.from_attempts(attempts)
589
+ for attempts in attempt_groups
590
+ if attempts and attempts[-1].status != "accepted"
591
+ ]
592
+
593
+
594
+ def run_repair_pipeline(request: RepairPipelineRequest) -> RepairPipelineResult:
595
+ """Run the public repair pipeline from detection through optional apply."""
596
+ source_path = request.source_path.resolve()
597
+ source_bytes = source_path.read_bytes()
598
+ df = read_csv(source_path)
599
+ with repair_stage_span("dataforge.repair.detect", row_count=row_count(df)):
600
+ issues = run_all_detectors(df, request.repair_schema)
601
+ with repair_stage_span("dataforge.repair.propose", issue_count=len(issues)):
602
+ accepted_fixes, attempt_groups = propose_repairs(
603
+ issues,
604
+ source_path,
605
+ copy_table(df),
606
+ request.repair_schema,
607
+ allow_llm=request.allow_llm,
608
+ model=request.model,
609
+ allow_pii=request.allow_pii,
610
+ confirm_pii=request.confirm_pii,
611
+ confirm_escalations=request.confirm_escalations,
612
+ interactive=request.interactive,
613
+ )
614
+
615
+ with repair_stage_span("dataforge.repair.safety.batch", fixes_count=len(accepted_fixes)):
616
+ batch_safety = SafetyFilter().evaluate_batch(accepted_fixes)
617
+ failures = _failed_attempts(attempt_groups)
618
+ transaction: RepairTransaction | None = None
619
+ txn_id: str | None = None
620
+ post_sha256: str | None = None
621
+ applied = False
622
+ reason = "No accepted fixes were produced."
623
+
624
+ if batch_safety.verdict != SafetyVerdict.ALLOW:
625
+ accepted_fixes = []
626
+ reason = batch_safety.reason
627
+ elif request.mode == "apply" and accepted_fixes:
628
+ txn_id = apply_transaction(source_path, accepted_fixes, source_bytes)
629
+ post_sha256 = sha256_file(source_path)
630
+ applied = True
631
+ reason = f"Applied {len(accepted_fixes)} fix(es)."
632
+ elif request.create_dry_run_transaction:
633
+ transaction, _log_path = create_repair_transaction(
634
+ source_path, accepted_fixes, source_bytes
635
+ )
636
+ txn_id = transaction.txn_id
637
+ reason = (
638
+ "Dry run completed without mutating the source file."
639
+ if accepted_fixes
640
+ else "No accepted fixes were produced."
641
+ )
642
+ elif accepted_fixes:
643
+ reason = "Dry run completed without mutating the source file."
644
+
645
+ if txn_id is not None and transaction is None:
646
+ # Replaying the log is unnecessary for the public contract here; this
647
+ # minimal receipt is intentionally enough for API callers.
648
+ transaction = None
649
+
650
+ receipt = RepairReceipt(
651
+ mode=request.mode,
652
+ applied=applied,
653
+ reversible=True,
654
+ source_path=str(source_path),
655
+ source_sha256=sha256_bytes(source_bytes),
656
+ post_sha256=post_sha256,
657
+ txn_id=txn_id,
658
+ allowed_columns=column_names(df),
659
+ valid_rows=list(range(row_count(df))),
660
+ issues_count=len(issues),
661
+ fixes_count=len(accepted_fixes),
662
+ reason=reason,
663
+ )
664
+ return RepairPipelineResult(
665
+ receipt=receipt,
666
+ issues=issues,
667
+ fixes=_verified_fixes(accepted_fixes, attempt_groups),
668
+ failures=failures,
669
+ transaction=transaction,
670
+ )
dataforge/env/__init__.py CHANGED
@@ -1 +1,22 @@
1
- """Environment package scaffolding for DataForge."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataForge RL environment OpenEnv-compatible data-quality environment.
2
+
3
+ Public API:
4
+ DataForgeEnv — Core environment with reset/step/state/close.
5
+ ResetResult — Return type of reset().
6
+ StepResult — Return type of step().
7
+ EnvState — State snapshot from state().
8
+ DataForgeObservation — Agent-visible observation.
9
+ ToolResult — Structured result from each action.
10
+ """
11
+
12
+ from dataforge.env.environment import DataForgeEnv, EnvState, ResetResult, StepResult
13
+ from dataforge.env.observation import DataForgeObservation, ToolResult
14
+
15
+ __all__ = [
16
+ "DataForgeEnv",
17
+ "DataForgeObservation",
18
+ "EnvState",
19
+ "ResetResult",
20
+ "StepResult",
21
+ "ToolResult",
22
+ ]
dataforge/env/environment.py ADDED
@@ -0,0 +1,884 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv-compatible DataForge RL environment.
2
+
3
+ Core environment implementing reset/step/state/close for data-quality
4
+ detection, diagnosis, and repair with typed tool-use actions.
5
+
6
+ No LLM calls. No disk writes. Dataset state is in-memory per episode.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ import random
13
+ import re
14
+ import uuid
15
+ from difflib import SequenceMatcher
16
+ from pathlib import Path
17
+ from typing import Any, cast
18
+
19
+ import duckdb
20
+ import pandas as pd
21
+ import sqlglot
22
+ import sqlglot.expressions as sqlglot_exp
23
+ from pydantic import BaseModel, Field
24
+
25
+ from dataforge.agent.scratchpad import Scratchpad
26
+ from dataforge.agent.tool_actions import (
27
+ Action,
28
+ Diagnose,
29
+ Fix,
30
+ Hypothesis,
31
+ InspectRows,
32
+ PatternMatch,
33
+ RootCause,
34
+ SqlQuery,
35
+ StatTest,
36
+ parse_action,
37
+ )
38
+ from dataforge.detectors import run_all_detectors
39
+ from dataforge.detectors.base import Issue
40
+ from dataforge.env.observation import DataForgeObservation, ToolResult
41
+ from dataforge.env.reward import (
42
+ P_FALSE_POS,
43
+ P_INVALID,
44
+ P_WRONG_FIX,
45
+ R_EXPLORE,
46
+ R_ROOT_CAUSE,
47
+ EpisodeMetrics,
48
+ RewardEngine,
49
+ )
50
+
51
+ logger = logging.getLogger("dataforge.env")
52
+
53
+ __all__ = [
54
+ "DataForgeEnv",
55
+ "EnvState",
56
+ "ResetResult",
57
+ "StepResult",
58
+ ]
59
+
60
+ _FIXTURES_DIR = Path(__file__).resolve().parents[1].parent / "fixtures"
61
+ _DEFAULT_CSV = _FIXTURES_DIR / "hospital_10rows.csv"
62
+ _DEFAULT_SCHEMA = _FIXTURES_DIR / "hospital_schema.yaml"
63
+ _MAX_STEPS = 30
64
+ _MAX_RESULT_ROWS = 20
65
+ _TOOL_HISTORY_LIMIT = 5
66
+ _NOISE_EPSILON = 0.15
67
+ _BLOCKED_SQL_FRAGMENTS = (
68
+ "attach",
69
+ "call ",
70
+ "copy ",
71
+ "detach",
72
+ "duckdb_extensions",
73
+ "filename",
74
+ "from_csv_auto",
75
+ "glob(",
76
+ "http://",
77
+ "https://",
78
+ "httpfs",
79
+ "install",
80
+ "load ",
81
+ "mysql_scan",
82
+ "parquet_scan",
83
+ "postgres_scan",
84
+ "pragma",
85
+ "read_csv",
86
+ "read_json",
87
+ "read_parquet",
88
+ "s3://",
89
+ "sqlite_scan",
90
+ )
91
+
92
+
93
+ # ═══════════════════════════════════════════════════════════════════════════
94
+ # Result models
95
+ # ═══════════════════════════════════════════════════════════════════════════
96
+
97
+
98
+ class ResetResult(BaseModel):
99
+ """Result of env.reset()."""
100
+
101
+ observation: DataForgeObservation
102
+ info: dict[str, Any] = Field(default_factory=dict)
103
+
104
+
105
+ class StepResult(BaseModel):
106
+ """Result of env.step()."""
107
+
108
+ observation: DataForgeObservation
109
+ reward: float = 0.0
110
+ done: bool = False
111
+ info: dict[str, Any] = Field(default_factory=dict)
112
+
113
+
114
+ class EnvState(BaseModel):
115
+ """Internal environment state snapshot."""
116
+
117
+ episode_id: str = ""
118
+ step_count: int = 0
119
+ task_id: str = ""
120
+ issues_detected: int = 0
121
+ issues_fixed: int = 0
122
+ false_positives: int = 0
123
+ total_issues: int = 0
124
+ is_done: bool = False
125
+
126
+
127
+ # ═══════════════════════════════════════════════════════════════════════════
128
+ # Environment
129
+ # ═══════════════════════════════════════════════════════════════════════════
130
+
131
+
132
+ class DataForgeEnv:
133
+ """OpenEnv-compatible RL environment for data quality repair.
134
+
135
+ Core API: ``reset()``, ``step()``, ``state()``, ``close()`` (no-op).
136
+
137
+ Example::
138
+
139
+ >>> env = DataForgeEnv()
140
+ >>> result = env.reset(seed=42)
141
+ >>> result.observation.done
142
+ False
143
+ """
144
+
145
+ def __init__(self, max_steps: int = _MAX_STEPS) -> None:
146
+ self._max_steps = max_steps
147
+ self._episode_id = ""
148
+ self._step_count = 0
149
+ self._df: pd.DataFrame = pd.DataFrame()
150
+ self._ground_truth: list[Issue] = []
151
+ self._found_issues: list[dict[str, Any]] = []
152
+ self._fixed_issues: list[dict[str, Any]] = []
153
+ self._false_positives = 0
154
+ self._cumulative_reward = 0.0
155
+ self._is_done = False
156
+ self._inspected_rows: set[int] = set()
157
+ self._noisy = False
158
+ self._noise_rng: random.Random | None = None
159
+ self._scratchpad = Scratchpad()
160
+ self._tool_history: list[ToolResult] = []
161
+ self._reward_engine = RewardEngine()
162
+ self._schema_info: dict[str, str] = {}
163
+ self._causal_dag_cache: Any = None
164
+ self._root_cause_labels: set[int] = set()
165
+
166
+ # ── Core API ──────────────────────────────────────────────────────────
167
+
168
+ def reset(self, seed: int | None = None, *, noisy: bool = False) -> ResetResult:
169
+ """Reset the environment for a new episode.
170
+
171
+ Args:
172
+ seed: Optional RNG seed for deterministic episodes.
173
+ noisy: If True, enable observation noise (epsilon=0.15).
174
+
175
+ Returns:
176
+ ResetResult with initial observation.
177
+ """
178
+ self._episode_id = str(uuid.uuid4())
179
+ self._step_count = 0
180
+ self._found_issues = []
181
+ self._fixed_issues = []
182
+ self._false_positives = 0
183
+ self._cumulative_reward = 0.0
184
+ self._is_done = False
185
+ self._inspected_rows = set()
186
+ self._scratchpad.reset()
187
+ self._tool_history = []
188
+ self._causal_dag_cache = None
189
+ self._root_cause_labels = set()
190
+ self._noisy = noisy
191
+ self._noise_rng = random.Random(seed if seed is not None else 0) if noisy else None
192
+
193
+ # Load fixture dataset
194
+ self._df = pd.read_csv(_DEFAULT_CSV, dtype=str)
195
+ self._schema_info = dict.fromkeys(self._df.columns, "str")
196
+ if _DEFAULT_SCHEMA.exists():
197
+ import yaml
198
+
199
+ with open(_DEFAULT_SCHEMA, encoding="utf-8") as f:
200
+ schema_data = yaml.safe_load(f)
201
+ if schema_data and "columns" in schema_data:
202
+ self._schema_info = schema_data["columns"]
203
+
204
+ # Run detectors for hidden ground truth
205
+ self._ground_truth = run_all_detectors(self._df)
206
+ logger.info(
207
+ "Episode %s: %d rows, %d ground-truth issues",
208
+ self._episode_id[:8],
209
+ len(self._df),
210
+ len(self._ground_truth),
211
+ )
212
+
213
+ # Initial observation with first 5 rows
214
+ initial_rows = cast(list[dict[str, Any]], self._df.head(5).to_dict(orient="records"))
215
+ obs = DataForgeObservation(
216
+ visible_rows=initial_rows,
217
+ step_budget_remaining=self._max_steps,
218
+ scratchpad_summary=self._scratchpad.summary(),
219
+ metadata={
220
+ "episode_id": self._episode_id,
221
+ "total_rows": len(self._df),
222
+ "total_columns": len(self._df.columns),
223
+ "schema": self._schema_info,
224
+ },
225
+ )
226
+ return ResetResult(observation=obs, info={"episode_id": self._episode_id})
227
+
228
+ def step(self, action: Action | dict[str, Any]) -> StepResult:
229
+ """Execute one agent action and return the result.
230
+
231
+ Args:
232
+ action: A typed Action model or raw dict to be parsed.
233
+
234
+ Returns:
235
+ StepResult with observation, reward, and done flag.
236
+ """
237
+ if self._is_done:
238
+ return self._terminal_result(0.0)
239
+
240
+ self._step_count += 1
241
+
242
+ # Parse if raw dict
243
+ if isinstance(action, dict):
244
+ try:
245
+ action = parse_action(action)
246
+ except Exception as exc:
247
+ return self._error_step(str(exc))
248
+
249
+ # Dispatch
250
+ try:
251
+ tool_result, reward = self._dispatch(action)
252
+ except Exception as exc:
253
+ logger.exception("Action dispatch error at step %d", self._step_count)
254
+ return self._error_step(str(exc))
255
+
256
+ # Late-step penalty
257
+ reward += self._reward_engine.compute_late_penalty(self._step_count, self._max_steps)
258
+
259
+ # Accumulate
260
+ self._cumulative_reward += reward
261
+
262
+ # Record in history
263
+ self._tool_history.append(tool_result)
264
+ if len(self._tool_history) > _TOOL_HISTORY_LIMIT:
265
+ self._tool_history = self._tool_history[-_TOOL_HISTORY_LIMIT:]
266
+
267
+ # Check termination
268
+ done = self._step_count >= self._max_steps
269
+ if done:
270
+ self._is_done = True
271
+ terminal = self._compute_terminal()
272
+ self._cumulative_reward = max(self._cumulative_reward, terminal)
273
+
274
+ obs = DataForgeObservation(
275
+ visible_rows=tool_result.data
276
+ if tool_result.action_type == "INSPECT_ROWS" and tool_result.success
277
+ else None,
278
+ scratchpad_summary=self._scratchpad.summary(),
279
+ step_budget_remaining=max(0, self._max_steps - self._step_count),
280
+ tool_usage_history=list(self._tool_history),
281
+ latest_result=tool_result,
282
+ done=done,
283
+ reward=reward,
284
+ cumulative_reward=self._cumulative_reward,
285
+ )
286
+ return StepResult(observation=obs, reward=reward, done=done)
287
+
288
+ def state(self) -> EnvState:
289
+ """Return current internal state snapshot."""
290
+ return EnvState(
291
+ episode_id=self._episode_id,
292
+ step_count=self._step_count,
293
+ issues_detected=len(self._found_issues),
294
+ issues_fixed=len(self._fixed_issues),
295
+ false_positives=self._false_positives,
296
+ total_issues=len(self._ground_truth),
297
+ is_done=self._is_done,
298
+ )
299
+
300
+ def close(self) -> None:
301
+ """No-op. Retained for OpenEnv container compatibility."""
302
+
303
+ # ── Dispatch ─────────────────��────────────────────────────────────────
304
+
305
+ def _dispatch(self, action: Action) -> tuple[ToolResult, float]:
306
+ """Route action to handler. Returns (tool_result, step_reward)."""
307
+ if isinstance(action, InspectRows):
308
+ return self._handle_inspect(action)
309
+ if isinstance(action, SqlQuery):
310
+ return self._handle_sql(action)
311
+ if isinstance(action, StatTest):
312
+ return self._handle_stat(action)
313
+ if isinstance(action, PatternMatch):
314
+ return self._handle_pattern(action)
315
+ if isinstance(action, Hypothesis):
316
+ return self._handle_hypothesis(action)
317
+ if isinstance(action, RootCause):
318
+ return self._handle_root_cause(action)
319
+ if isinstance(action, Diagnose):
320
+ return self._handle_diagnose(action)
321
+ if isinstance(action, Fix):
322
+ return self._handle_fix(action)
323
+ return ToolResult(
324
+ action_type="UNKNOWN",
325
+ success=False,
326
+ error={"verdict": "error", "reason": "Unknown action type"},
327
+ ), P_INVALID
328
+
329
+ # ── Action handlers ───────────────────────────────────────────────────
330
+
331
+ def _handle_inspect(self, action: InspectRows) -> tuple[ToolResult, float]:
332
+ """Handle INSPECT_ROWS: return dataset rows."""
333
+ valid_indices = [i for i in action.row_indices if 0 <= i < len(self._df)]
334
+ if not valid_indices:
335
+ return ToolResult(
336
+ action_type="INSPECT_ROWS",
337
+ success=False,
338
+ error={"verdict": "error", "reason": "No valid row indices"},
339
+ ), P_INVALID
340
+
341
+ # Apply 20-row cap
342
+ valid_indices = valid_indices[:20]
343
+ rows = self._df.iloc[valid_indices]
344
+ if action.column_names:
345
+ valid_cols = [c for c in action.column_names if c in self._df.columns]
346
+ if valid_cols:
347
+ rows = rows[valid_cols]
348
+
349
+ row_dicts = cast(list[dict[str, Any]], rows.to_dict(orient="records"))
350
+ for i, idx in enumerate(valid_indices[: len(row_dicts)]):
351
+ row_dicts[i]["_row_index"] = idx
352
+
353
+ # Noise injection
354
+ if self._noisy and self._noise_rng:
355
+ row_dicts = self._inject_noise(row_dicts)
356
+
357
+ # Exploration bonus
358
+ new_indices = set(valid_indices) - self._inspected_rows
359
+ self._inspected_rows.update(valid_indices)
360
+ gt_rows = {issue.row for issue in self._ground_truth}
361
+ found_rows = {f["row"] for f in self._found_issues}
362
+ bonus = self._reward_engine.compute_exploration_bonus(
363
+ new_indices,
364
+ self._inspected_rows,
365
+ len(self._df),
366
+ gt_rows,
367
+ found_rows,
368
+ )
369
+ return ToolResult(action_type="INSPECT_ROWS", success=True, data=row_dicts), bonus
370
+
371
+ def _handle_sql(self, action: SqlQuery) -> tuple[ToolResult, float]:
372
+ """Handle SQL_QUERY: execute read-only SQL via DuckDB."""
373
+ # Validate read-only
374
+ try:
375
+ parsed = [stmt for stmt in sqlglot.parse(action.query) if stmt is not None]
376
+ except sqlglot.errors.ParseError as exc:
377
+ return ToolResult(
378
+ action_type="SQL_QUERY",
379
+ success=False,
380
+ error={
381
+ "verdict": "error",
382
+ "reason": str(exc),
383
+ "suggested_constraint": "Use valid SQL syntax",
384
+ },
385
+ ), P_INVALID
386
+
387
+ if len(parsed) != 1:
388
+ return ToolResult(
389
+ action_type="SQL_QUERY",
390
+ success=False,
391
+ error={
392
+ "verdict": "rejected",
393
+ "reason": "Exactly one SELECT statement is allowed.",
394
+ "suggested_constraint": "Use a single read-only SELECT statement.",
395
+ },
396
+ ), P_INVALID
397
+
398
+ normalized_query = f" {action.query.lower()} "
399
+ blocked = next(
400
+ (fragment for fragment in _BLOCKED_SQL_FRAGMENTS if fragment in normalized_query),
401
+ None,
402
+ )
403
+ if blocked is not None:
404
+ return ToolResult(
405
+ action_type="SQL_QUERY",
406
+ success=False,
407
+ error={
408
+ "verdict": "rejected",
409
+ "reason": "SQL_QUERY may only read from the registered data relation.",
410
+ "suggested_constraint": "Query the in-memory data table without file, network, extension, or table functions.",
411
+ },
412
+ ), P_INVALID
413
+
414
+ for stmt in parsed:
415
+ if stmt.key not in ("select",):
416
+ return ToolResult(
417
+ action_type="SQL_QUERY",
418
+ success=False,
419
+ error={
420
+ "verdict": "rejected",
421
+ "reason": f"Only SELECT queries allowed, got {stmt.key.upper()}",
422
+ "suggested_constraint": "Use SELECT statements only",
423
+ },
424
+ ), P_INVALID
425
+
426
+ for table in stmt.find_all(sqlglot_exp.Table):
427
+ if table.name.lower() != "data":
428
+ return ToolResult(
429
+ action_type="SQL_QUERY",
430
+ success=False,
431
+ error={
432
+ "verdict": "rejected",
433
+ "reason": (
434
+ "SQL_QUERY may only reference the registered data relation; "
435
+ f"got '{table.name}'."
436
+ ),
437
+ "suggested_constraint": "Use FROM data for tabular queries.",
438
+ },
439
+ ), P_INVALID
440
+
441
+ try:
442
+ conn = duckdb.connect(":memory:")
443
+ conn.register("data", self._df)
444
+ result_df = conn.execute(action.query).fetchdf()
445
+ conn.close()
446
+ rows = result_df.head(_MAX_RESULT_ROWS).to_dict(orient="records")
447
+ return ToolResult(action_type="SQL_QUERY", success=True, data=rows), 0.0
448
+ except duckdb.Error as exc:
449
+ return ToolResult(
450
+ action_type="SQL_QUERY",
451
+ success=False,
452
+ error={"verdict": "error", "reason": str(exc)},
453
+ ), P_INVALID
454
+
455
+ def _handle_stat(self, action: StatTest) -> tuple[ToolResult, float]:
456
+ """Handle STAT_TEST: run zscore/iqr/ks on a column."""
457
+ if action.column not in self._df.columns:
458
+ return ToolResult(
459
+ action_type="STAT_TEST",
460
+ success=False,
461
+ error={"verdict": "error", "reason": f"Column '{action.column}' not found"},
462
+ ), P_INVALID
463
+
464
+ try:
465
+ col = pd.to_numeric(self._df[action.column], errors="coerce").dropna()
466
+ if len(col) == 0:
467
+ return ToolResult(
468
+ action_type="STAT_TEST",
469
+ success=False,
470
+ error={
471
+ "verdict": "error",
472
+ "reason": f"No numeric values in column '{action.column}'",
473
+ },
474
+ ), P_INVALID
475
+ except Exception as exc:
476
+ return ToolResult(
477
+ action_type="STAT_TEST",
478
+ success=False,
479
+ error={"verdict": "error", "reason": str(exc)},
480
+ ), P_INVALID
481
+
482
+ from scipy import stats as scipy_stats # type: ignore[import-untyped]
483
+
484
+ if action.test_type == "zscore":
485
+ zscores = scipy_stats.zscore(col)
486
+ threshold = action.threshold or 3.0
487
+ outliers = col.index[abs(zscores) > threshold].tolist()
488
+ data = {
489
+ "test": "zscore",
490
+ "threshold": threshold,
491
+ "outlier_indices": outliers,
492
+ "n_outliers": len(outliers),
493
+ "mean": float(col.mean()),
494
+ "std": float(col.std()),
495
+ }
496
+ elif action.test_type == "iqr":
497
+ q1, q3 = float(col.quantile(0.25)), float(col.quantile(0.75))
498
+ iqr_val = q3 - q1
499
+ factor = action.threshold or 1.5
500
+ lower, upper = q1 - factor * iqr_val, q3 + factor * iqr_val
501
+ outliers = col.index[(col < lower) | (col > upper)].tolist()
502
+ data = {
503
+ "test": "iqr",
504
+ "q1": q1,
505
+ "q3": q3,
506
+ "iqr": iqr_val,
507
+ "lower": lower,
508
+ "upper": upper,
509
+ "outlier_indices": outliers,
510
+ }
511
+ elif action.test_type == "ks":
512
+ stat_val, p_val = scipy_stats.kstest(
513
+ col, "norm", args=(float(col.mean()), float(col.std()))
514
+ )
515
+ data = {
516
+ "test": "ks",
517
+ "statistic": float(stat_val),
518
+ "p_value": float(p_val),
519
+ "normal": p_val > 0.05,
520
+ }
521
+ else:
522
+ return ToolResult(
523
+ action_type="STAT_TEST",
524
+ success=False,
525
+ error={"verdict": "error", "reason": f"Unknown test type: {action.test_type}"},
526
+ ), P_INVALID
527
+
528
+ return ToolResult(action_type="STAT_TEST", success=True, data=data), 0.0
529
+
530
+ def _handle_pattern(self, action: PatternMatch) -> tuple[ToolResult, float]:
531
+ """Handle PATTERN_MATCH: evaluate regex against column values."""
532
+ if action.column not in self._df.columns:
533
+ return ToolResult(
534
+ action_type="PATTERN_MATCH",
535
+ success=False,
536
+ error={"verdict": "error", "reason": f"Column '{action.column}' not found"},
537
+ ), P_INVALID
538
+
539
+ try:
540
+ compiled = re.compile(action.pattern)
541
+ except re.error as exc:
542
+ return ToolResult(
543
+ action_type="PATTERN_MATCH",
544
+ success=False,
545
+ error={"verdict": "error", "reason": f"Invalid regex: {exc}"},
546
+ ), P_INVALID
547
+
548
+ matches: list[dict[str, Any]] = []
549
+ for idx, val in enumerate(self._df[action.column].astype(str)):
550
+ is_match = bool(compiled.search(val))
551
+ if is_match == action.expect_match:
552
+ matches.append({"row": idx, "column": action.column, "value": val})
553
+ return ToolResult(
554
+ action_type="PATTERN_MATCH",
555
+ success=True,
556
+ data={"matches": matches[:_MAX_RESULT_ROWS], "total_matches": len(matches)},
557
+ ), 0.0
558
+
559
+ def _handle_hypothesis(self, action: Hypothesis) -> tuple[ToolResult, float]:
560
+ """Handle HYPOTHESIS: record claim and award root-cause credit."""
561
+ self._scratchpad.add_hypothesis(
562
+ action.claim,
563
+ action.affected_rows,
564
+ action.affected_columns,
565
+ action.root_cause_type,
566
+ )
567
+ # Check for root-cause match against ground truth
568
+ credit = 0.0
569
+ for issue in self._ground_truth:
570
+ if (
571
+ issue.row in action.affected_rows
572
+ and issue.column in action.affected_columns
573
+ and issue.issue_type == action.root_cause_type
574
+ ):
575
+ credit += R_EXPLORE
576
+ data = {"recorded": True, "root_cause_credit": credit}
577
+ return ToolResult(action_type="HYPOTHESIS", success=True, data=data), credit
578
+
579
+ def _handle_root_cause(self, action: RootCause) -> tuple[ToolResult, float]:
580
+ """Handle ROOT_CAUSE: analyze detected issues for minimal roots."""
581
+ if not self._found_issues:
582
+ return ToolResult(
583
+ action_type="ROOT_CAUSE",
584
+ success=False,
585
+ error={"verdict": "error", "reason": "No detected issues are available"},
586
+ ), P_INVALID
587
+
588
+ invalid = [idx for idx in action.error_indices if idx >= len(self._found_issues)]
589
+ if invalid:
590
+ return ToolResult(
591
+ action_type="ROOT_CAUSE",
592
+ success=False,
593
+ error={
594
+ "verdict": "error",
595
+ "reason": f"Detected issue indices out of range: {invalid}",
596
+ },
597
+ ), P_INVALID
598
+
599
+ from dataforge.causal.pc import discover_causal_dag
600
+ from dataforge.causal.root_cause import CausalRootCauseAnalyzer, evidence_from_issue
601
+
602
+ if self._causal_dag_cache is None:
603
+ self._causal_dag_cache = discover_causal_dag(self._df).dag
604
+
605
+ selected = [
606
+ evidence_from_issue(index, self._found_issues[index]) for index in action.error_indices
607
+ ]
608
+ result = CausalRootCauseAnalyzer(self._causal_dag_cache).analyze(selected)
609
+ data = result.model_dump(mode="json")
610
+ reward = self._root_cause_reward(set(result.root_indices))
611
+ return ToolResult(action_type="ROOT_CAUSE", success=True, data=data), reward
612
+
613
+ def _handle_diagnose(self, action: Diagnose) -> tuple[ToolResult, float]:
614
+ """Handle DIAGNOSE: score against ground truth."""
615
+ if action.row < 0 or action.row >= len(self._df):
616
+ return ToolResult(
617
+ action_type="DIAGNOSE",
618
+ success=False,
619
+ error={"verdict": "error", "reason": f"Row {action.row} out of bounds"},
620
+ ), P_INVALID
621
+ if action.column not in self._df.columns:
622
+ return ToolResult(
623
+ action_type="DIAGNOSE",
624
+ success=False,
625
+ error={"verdict": "error", "reason": f"Column '{action.column}' not found"},
626
+ ), P_INVALID
627
+
628
+ # Already reported?
629
+ for found in self._found_issues:
630
+ if found["row"] == action.row and found["column"] == action.column:
631
+ return ToolResult(
632
+ action_type="DIAGNOSE", success=True, data={"result": "already_found"}
633
+ ), 0.0
634
+
635
+ # Match ground truth
636
+ for issue in self._ground_truth:
637
+ if issue.row == action.row and issue.column == action.column:
638
+ type_match = action.issue_type == issue.issue_type
639
+ reward = self._reward_engine.diagnose_reward(type_match)
640
+ self._found_issues.append(
641
+ {"row": action.row, "column": action.column, "type": action.issue_type}
642
+ )
643
+ self._scratchpad.confirm_issue(action.row, action.column, action.issue_type)
644
+ return ToolResult(
645
+ action_type="DIAGNOSE",
646
+ success=True,
647
+ data={"result": "correct", "type_match": type_match},
648
+ ), reward
649
+
650
+ # False positive
651
+ self._false_positives += 1
652
+ return ToolResult(
653
+ action_type="DIAGNOSE", success=True, data={"result": "false_positive"}
654
+ ), P_FALSE_POS
655
+
656
+ def _root_cause_reward(self, root_indices: set[int]) -> float:
657
+ """Return root-cause bonus only when task labels are available."""
658
+ if not self._root_cause_labels:
659
+ return 0.0
660
+ return R_ROOT_CAUSE if root_indices == self._root_cause_labels else 0.0
661
+
662
+ def _handle_fix(self, action: Fix) -> tuple[ToolResult, float]:
663
+ """Handle FIX: validate through safety/SMT, then score."""
664
+ if action.row < 0 or action.row >= len(self._df):
665
+ return ToolResult(
666
+ action_type="FIX",
667
+ success=False,
668
+ error={"verdict": "error", "reason": f"Row {action.row} out of bounds"},
669
+ ), P_INVALID
670
+ if action.column not in self._df.columns:
671
+ return ToolResult(
672
+ action_type="FIX",
673
+ success=False,
674
+ error={"verdict": "error", "reason": f"Column '{action.column}' not found"},
675
+ ), P_INVALID
676
+
677
+ # Already fixed?
678
+ for fixed in self._fixed_issues:
679
+ if fixed["row"] == action.row and fixed["column"] == action.column:
680
+ return ToolResult(
681
+ action_type="FIX", success=True, data={"result": "already_fixed"}
682
+ ), 0.0
683
+
684
+ # Safety filter + SMT verifier (best-effort, no crash on import failure)
685
+ try:
686
+ safety_ok, safety_msg = self._check_safety(action)
687
+ except Exception as exc:
688
+ logger.warning("Safety pipeline failed closed: %s", exc)
689
+ safety_ok = False
690
+ safety_msg = f"Safety pipeline failed closed: {exc}"
691
+ if not safety_ok:
692
+ return ToolResult(
693
+ action_type="FIX",
694
+ success=False,
695
+ error={"verdict": "rejected", "reason": safety_msg},
696
+ ), P_INVALID
697
+
698
+ # Match ground truth
699
+ for issue in self._ground_truth:
700
+ if issue.row == action.row and issue.column == action.column:
701
+ if issue.expected is None:
702
+ return ToolResult(
703
+ action_type="FIX", success=True, data={"result": "detection_only"}
704
+ ), 0.0
705
+
706
+ # Exact match (case-insensitive)
707
+ if action.new_value.strip().lower() == str(issue.expected).lower():
708
+ reward = self._reward_engine.fix_reward(
709
+ exact=True, has_justification=bool(action.justification)
710
+ )
711
+ self._fixed_issues.append(
712
+ {"row": action.row, "column": action.column, "value": action.new_value}
713
+ )
714
+ self._auto_diagnose(action, issue)
715
+ return ToolResult(
716
+ action_type="FIX", success=True, data={"result": "correct"}
717
+ ), reward
718
+
719
+ # Partial: numeric within 1%
720
+ try:
721
+ prov = float(action.new_value.strip())
722
+ exp = float(str(issue.expected))
723
+ rel_err = abs(prov - exp) / abs(exp) if exp != 0 else abs(prov)
724
+ if rel_err < 0.01:
725
+ reward = self._reward_engine.fix_reward(
726
+ exact=False, has_justification=bool(action.justification)
727
+ )
728
+ self._fixed_issues.append(
729
+ {"row": action.row, "column": action.column, "value": action.new_value}
730
+ )
731
+ self._auto_diagnose(action, issue)
732
+ return ToolResult(
733
+ action_type="FIX", success=True, data={"result": "partial_numeric"}
734
+ ), reward
735
+ except (ValueError, TypeError):
736
+ pass
737
+
738
+ # Partial: string similarity >= 85%
739
+ sim = SequenceMatcher(
740
+ None, action.new_value.lower(), str(issue.expected).lower()
741
+ ).ratio()
742
+ if sim >= 0.85:
743
+ reward = self._reward_engine.fix_reward(
744
+ exact=False, has_justification=bool(action.justification)
745
+ )
746
+ self._fixed_issues.append(
747
+ {"row": action.row, "column": action.column, "value": action.new_value}
748
+ )
749
+ self._auto_diagnose(action, issue)
750
+ return ToolResult(
751
+ action_type="FIX", success=True, data={"result": "partial_string"}
752
+ ), reward
753
+
754
+ return ToolResult(
755
+ action_type="FIX", success=True, data={"result": "wrong_value"}
756
+ ), P_WRONG_FIX
757
+
758
+ return ToolResult(
759
+ action_type="FIX", success=True, data={"result": "no_issue_at_location"}
760
+ ), P_WRONG_FIX
761
+
762
+ # ── Helpers ────────────────────────────────────────────────────────────
763
+
764
+ def _check_safety(self, action: Fix) -> tuple[bool, str]:
765
+ """Run SafetyFilter + SMTVerifier. Returns (ok, message)."""
766
+ try:
767
+ from dataforge.repairers.base import ProposedFix
768
+ from dataforge.safety.filter import SafetyContext, SafetyFilter, SafetyVerdict
769
+ from dataforge.transactions.txn import CellFix
770
+ from dataforge.verifier.smt import SMTVerifier, VerificationVerdict
771
+
772
+ old_val = str(self._df.at[action.row, action.column])
773
+ cell_fix = CellFix(
774
+ row=action.row,
775
+ column=action.column,
776
+ old_value=old_val,
777
+ new_value=action.new_value,
778
+ detector_id="agent",
779
+ )
780
+ proposed = ProposedFix(
781
+ fix=cell_fix,
782
+ reason=action.justification,
783
+ confidence=0.8,
784
+ provenance="deterministic",
785
+ )
786
+
787
+ sf = SafetyFilter()
788
+ ctx = SafetyContext()
789
+ sr = sf.evaluate(proposed, None, ctx)
790
+ if sr.verdict == SafetyVerdict.DENY:
791
+ return False, f"Safety filter denied: {sr.reason}"
792
+
793
+ verifier = SMTVerifier()
794
+ vr = verifier.verify(self._df, [proposed])
795
+ if vr.verdict == VerificationVerdict.REJECT:
796
+ return False, f"SMT verifier rejected: {vr.reason}"
797
+ if vr.verdict == VerificationVerdict.UNKNOWN:
798
+ return False, f"SMT verifier returned unknown: {vr.reason}"
799
+
800
+ return True, "Passed safety and verification"
801
+ except ImportError as exc:
802
+ return False, f"Safety/verifier dependency unavailable: {exc}"
803
+
804
+ def _auto_diagnose(self, action: Fix, issue: Issue) -> None:
805
+ """Auto-credit diagnosis when agent fixes without diagnosing first."""
806
+ already = any(
807
+ f["row"] == action.row and f["column"] == action.column for f in self._found_issues
808
+ )
809
+ if not already:
810
+ self._found_issues.append(
811
+ {"row": action.row, "column": action.column, "type": issue.issue_type}
812
+ )
813
+
814
+ def _inject_noise(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
815
+ """Apply deterministic observation noise (epsilon=0.15)."""
816
+ if not self._noise_rng:
817
+ return rows
818
+ noisy = []
819
+ for row in rows:
820
+ row_copy = dict(row)
821
+ if self._noise_rng.random() < _NOISE_EPSILON:
822
+ cols = [k for k in row_copy if k != "_row_index"]
823
+ if cols:
824
+ col = self._noise_rng.choice(cols)
825
+ val = row_copy[col]
826
+ if isinstance(val, str) and len(val) > 3:
827
+ row_copy[col] = (
828
+ val[: -(self._noise_rng.randint(1, 3))]
829
+ if self._noise_rng.random() < 0.5
830
+ else val.swapcase()
831
+ )
832
+ noisy.append(row_copy)
833
+ return noisy
834
+
835
+ def _compute_terminal(self) -> float:
836
+ """Compute terminal score."""
837
+ fixable = [i for i in self._ground_truth if i.expected is not None]
838
+ metrics = EpisodeMetrics(
839
+ found_issues=len(self._found_issues),
840
+ total_issues=len(self._ground_truth),
841
+ fixed_issues=len(self._fixed_issues),
842
+ fixable_issues=len(fixable),
843
+ false_positives=self._false_positives,
844
+ )
845
+ return self._reward_engine.compute_terminal_score(metrics)
846
+
847
+ def _error_step(self, message: str) -> StepResult:
848
+ """Build error StepResult."""
849
+ tr = ToolResult(
850
+ action_type="ERROR", success=False, error={"verdict": "error", "reason": message}
851
+ )
852
+ self._tool_history.append(tr)
853
+ self._cumulative_reward += P_INVALID
854
+ done = self._step_count >= self._max_steps
855
+ if done:
856
+ self._is_done = True
857
+ return StepResult(
858
+ observation=DataForgeObservation(
859
+ step_budget_remaining=max(0, self._max_steps - self._step_count),
860
+ tool_usage_history=list(self._tool_history[-_TOOL_HISTORY_LIMIT:]),
861
+ latest_result=tr,
862
+ done=done,
863
+ reward=P_INVALID,
864
+ cumulative_reward=self._cumulative_reward,
865
+ scratchpad_summary=self._scratchpad.summary(),
866
+ ),
867
+ reward=P_INVALID,
868
+ done=done,
869
+ )
870
+
871
+ def _terminal_result(self, reward: float) -> StepResult:
872
+ """Build terminal StepResult for already-done episodes."""
873
+ return StepResult(
874
+ observation=DataForgeObservation(
875
+ step_budget_remaining=0,
876
+ done=True,
877
+ reward=reward,
878
+ cumulative_reward=self._cumulative_reward,
879
+ scratchpad_summary=self._scratchpad.summary(),
880
+ tool_usage_history=list(self._tool_history[-_TOOL_HISTORY_LIMIT:]),
881
+ ),
882
+ reward=reward,
883
+ done=True,
884
+ )
dataforge/env/observation.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Observation builder for the DataForge RL environment.
2
+
3
+ Constructs agent-visible observations containing partial data views,
4
+ scratchpad summaries, tool results, and step budget information.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+ __all__ = ["DataForgeObservation", "ToolResult"]
14
+
15
+
16
+ class ToolResult(BaseModel):
17
+ """Result of a single tool-use action.
18
+
19
+ Args:
20
+ action_type: The action type that produced this result.
21
+ success: Whether the action succeeded.
22
+ data: Action-specific result data (rows, stats, matches, etc.).
23
+ error: Structured error information if the action failed.
24
+ """
25
+
26
+ action_type: str
27
+ success: bool = True
28
+ data: Any = None
29
+ error: dict[str, Any] | None = None
30
+
31
+ model_config = {"frozen": True}
32
+
33
+
34
+ class DataForgeObservation(BaseModel):
35
+ """Agent-visible observation returned after each environment step.
36
+
37
+ Args:
38
+ visible_rows: Dataset rows returned by INSPECT_ROWS or reset.
39
+ detector_hints: Optional hints from detectors (partial ground truth).
40
+ scratchpad_summary: Compact summary of the agent's scratchpad.
41
+ step_budget_remaining: Steps left before auto-finalize.
42
+ tool_usage_history: Last 5 tool results for context.
43
+ latest_result: Result of the most recent action.
44
+ done: Whether the episode has ended.
45
+ reward: Step reward.
46
+ cumulative_reward: Running total reward for the episode.
47
+ metadata: Additional key-value metadata.
48
+ """
49
+
50
+ visible_rows: list[dict[str, Any]] | None = None
51
+ detector_hints: list[str] | None = None
52
+ scratchpad_summary: str = ""
53
+ step_budget_remaining: int = 0
54
+ tool_usage_history: list[ToolResult] = Field(default_factory=list)
55
+ latest_result: ToolResult | None = None
56
+ done: bool = False
57
+ reward: float = 0.0
58
+ cumulative_reward: float = 0.0
59
+ metadata: dict[str, Any] = Field(default_factory=dict)
60
+
61
+ model_config = {"frozen": True}
dataforge/env/openenv_core.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv-core adapter for the DataForge RL environment."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from pydantic import Field
8
+
9
+ from dataforge.env.environment import DataForgeEnv
10
+
11
+ try:
12
+ from openenv.core.env_server import (
13
+ Action as OpenEnvAction,
14
+ )
15
+ from openenv.core.env_server import (
16
+ Environment as OpenEnvEnvironment,
17
+ )
18
+ from openenv.core.env_server import (
19
+ Observation as OpenEnvObservation,
20
+ )
21
+ from openenv.core.env_server import (
22
+ create_app,
23
+ )
24
+ except ImportError as exc: # pragma: no cover - exercised only without openenv extra
25
+ raise RuntimeError(
26
+ "The OpenEnv adapter requires the openenv extra: "
27
+ "pip install 'dataforge15[openenv]'."
28
+ ) from exc
29
+
30
+
31
+ class DataForgeOpenEnvAction(OpenEnvAction):
32
+ """OpenEnv action wrapper for DataForge's typed action payloads."""
33
+
34
+ action_type: str = Field(min_length=1)
35
+ row_indices: list[int] | None = None
36
+ column_names: list[str] | None = None
37
+ query: str | None = None
38
+ sql: str | None = None
39
+ test_type: str | None = None
40
+ test: str | None = None
41
+ column: str | None = None
42
+ threshold: float | None = None
43
+ pattern: str | None = None
44
+ regex: str | None = None
45
+ expect_match: bool | None = None
46
+ claim: str | None = None
47
+ affected_rows: list[int] | None = None
48
+ affected_columns: list[str] | None = None
49
+ root_cause_type: str | None = None
50
+ error_indices: list[int] | None = None
51
+ row: int | None = None
52
+ issue_type: str | None = None
53
+ new_value: str | None = None
54
+ proposed_value: str | None = None
55
+ justification: str | None = None
56
+ fix_type: str | None = None
57
+
58
+ def as_dataforge_payload(self) -> dict[str, Any]:
59
+ """Return the action payload expected by ``DataForgeEnv.step``."""
60
+ payload = self.model_dump(exclude_none=True)
61
+ payload.pop("metadata", None)
62
+ return payload
63
+
64
+
65
+ class DataForgeOpenEnvObservation(OpenEnvObservation):
66
+ """OpenEnv observation model mirroring DataForge's native observation."""
67
+
68
+ visible_rows: list[dict[str, Any]] | None = None
69
+ detector_hints: list[str] | None = None
70
+ scratchpad_summary: str = ""
71
+ step_budget_remaining: int = 0
72
+ tool_usage_history: list[dict[str, Any]] = Field(default_factory=list)
73
+ latest_result: dict[str, Any] | None = None
74
+ cumulative_reward: float = 0.0
75
+
76
+
77
+ def _to_openenv_observation(payload: dict[str, Any]) -> DataForgeOpenEnvObservation:
78
+ """Convert a native DataForge observation dictionary into OpenEnv shape."""
79
+ return DataForgeOpenEnvObservation(
80
+ visible_rows=payload.get("visible_rows"),
81
+ detector_hints=payload.get("detector_hints"),
82
+ scratchpad_summary=str(payload.get("scratchpad_summary", "")),
83
+ step_budget_remaining=int(payload.get("step_budget_remaining", 0)),
84
+ tool_usage_history=list(payload.get("tool_usage_history") or []),
85
+ latest_result=payload.get("latest_result"),
86
+ done=bool(payload.get("done", False)),
87
+ reward=payload.get("reward"),
88
+ cumulative_reward=float(payload.get("cumulative_reward", 0.0)),
89
+ metadata=dict(payload.get("metadata") or {}),
90
+ )
91
+
92
+
93
+ class DataForgeOpenEnv(OpenEnvEnvironment):
94
+ """OpenEnv-native environment wrapper."""
95
+
96
+ SUPPORTS_CONCURRENT_SESSIONS = True
97
+
98
+ def __init__(self) -> None:
99
+ super().__init__()
100
+ self._env = DataForgeEnv()
101
+ self._last_observation: DataForgeOpenEnvObservation | None = None
102
+
103
+ def reset(
104
+ self,
105
+ seed: int | None = None,
106
+ episode_id: str | None = None,
107
+ **kwargs: Any,
108
+ ) -> DataForgeOpenEnvObservation:
109
+ """Reset the wrapped DataForge environment."""
110
+ del episode_id, kwargs
111
+ result = self._env.reset(seed=seed)
112
+ observation = _to_openenv_observation(result.observation.model_dump(mode="json"))
113
+ self._last_observation = observation
114
+ return observation
115
+
116
+ def step(
117
+ self,
118
+ action: DataForgeOpenEnvAction,
119
+ timeout_s: float | None = None,
120
+ **kwargs: Any,
121
+ ) -> DataForgeOpenEnvObservation:
122
+ """Step the wrapped DataForge environment."""
123
+ del timeout_s, kwargs
124
+ result = self._env.step(action.as_dataforge_payload())
125
+ observation = _to_openenv_observation(result.observation.model_dump(mode="json"))
126
+ self._last_observation = observation
127
+ return observation
128
+
129
+ def state(self) -> DataForgeOpenEnvObservation:
130
+ """Return the latest observation or reset lazily."""
131
+ if self._last_observation is None:
132
+ return self.reset()
133
+ return self._last_observation
134
+
135
+ def close(self) -> None:
136
+ """Close the wrapped environment."""
137
+ self._env.close()
138
+
139
+
140
+ app = create_app(
141
+ DataForgeOpenEnv,
142
+ DataForgeOpenEnvAction,
143
+ DataForgeOpenEnvObservation,
144
+ env_name="dataforge-env",
145
+ max_concurrent_envs=64,
146
+ )
dataforge/env/reward.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward engine for the DataForge RL environment.
2
+
3
+ All constants and formulas are derived bit-for-bit from REWARD_DESIGN.md.
4
+
5
+ Terminal score: detection_rate * 0.40 + fix_rate * 0.60 - false_positives * fp_rate
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+
12
+ __all__ = [
13
+ "DETECTION_WEIGHT",
14
+ "FALSE_POS_PENALTY_RATE",
15
+ "FIX_WEIGHT",
16
+ "LATE_STEP_THRESHOLD",
17
+ "P_FALSE_POS",
18
+ "P_INVALID",
19
+ "P_LATE_STEP",
20
+ "P_REINSPECT",
21
+ "P_WRONG_FIX",
22
+ "R_DIAGNOSE",
23
+ "R_EXPLORE",
24
+ "R_FIX",
25
+ "R_FIX_PARTIAL",
26
+ "R_JUSTIFY_BONUS",
27
+ "R_ROOT_CAUSE",
28
+ "R_TYPE_BONUS",
29
+ "SPAM_THRESHOLD",
30
+ "EpisodeMetrics",
31
+ "RewardEngine",
32
+ ]
33
+
34
+ # Positive rewards
35
+ R_DIAGNOSE: float = 0.10
36
+ R_TYPE_BONUS: float = 0.05
37
+ R_FIX: float = 0.15
38
+ R_FIX_PARTIAL: float = 0.075
39
+ R_JUSTIFY_BONUS: float = 0.05
40
+ R_EXPLORE: float = 0.01
41
+ R_ROOT_CAUSE: float = 0.10
42
+
43
+ # Negative penalties
44
+ P_FALSE_POS: float = -0.05
45
+ P_WRONG_FIX: float = -0.08
46
+ P_LATE_STEP: float = -0.02
47
+ P_INVALID: float = -0.01
48
+ P_REINSPECT: float = -0.01
49
+
50
+ # Thresholds
51
+ LATE_STEP_THRESHOLD: float = 0.80
52
+ DETECTION_WEIGHT: float = 0.40
53
+ FIX_WEIGHT: float = 0.60
54
+ FALSE_POS_PENALTY_RATE: float = 0.05
55
+ SPAM_THRESHOLD: float = 2.0
56
+
57
+
58
+ @dataclass
59
+ class EpisodeMetrics:
60
+ """Accumulated metrics for terminal score computation."""
61
+
62
+ found_issues: int = 0
63
+ total_issues: int = 0
64
+ fixed_issues: int = 0
65
+ fixable_issues: int = 0
66
+ false_positives: int = 0
67
+
68
+ @property
69
+ def total_diagnoses(self) -> int:
70
+ """Total diagnosis attempts (correct + incorrect)."""
71
+ return self.found_issues + self.false_positives
72
+
73
+
74
+ class RewardEngine:
75
+ """Computes dense per-step and terminal rewards."""
76
+
77
+ def compute_terminal_score(self, metrics: EpisodeMetrics) -> float:
78
+ """Compute terminal score per REWARD_DESIGN.md formula."""
79
+ if metrics.total_issues == 0:
80
+ return 0.0
81
+ detection_rate = metrics.found_issues / metrics.total_issues
82
+ fix_rate = (
83
+ metrics.fixed_issues / metrics.fixable_issues if metrics.fixable_issues > 0 else 0.0
84
+ )
85
+ fp_rate = FALSE_POS_PENALTY_RATE
86
+ if (
87
+ metrics.total_issues > 0
88
+ and metrics.total_diagnoses > SPAM_THRESHOLD * metrics.total_issues
89
+ ):
90
+ fp_rate *= 2.0
91
+ penalty = metrics.false_positives * fp_rate
92
+ raw = detection_rate * DETECTION_WEIGHT + fix_rate * FIX_WEIGHT - penalty
93
+ return round(max(0.0, min(1.0, raw)), 4)
94
+
95
+ def compute_late_penalty(self, step: int, max_steps: int) -> float:
96
+ """Return P_LATE_STEP if past 80% budget, else 0.0."""
97
+ threshold = int(max_steps * LATE_STEP_THRESHOLD)
98
+ return P_LATE_STEP if step > threshold else 0.0
99
+
100
+ def compute_exploration_bonus(
101
+ self,
102
+ new_row_indices: set[int],
103
+ inspected_rows: set[int],
104
+ total_rows: int,
105
+ ground_truth_rows: set[int],
106
+ found_issue_rows: set[int],
107
+ ) -> float:
108
+ """Compute exploration bonus for newly-inspected rows."""
109
+ if not new_row_indices:
110
+ return P_REINSPECT
111
+ undiscovered = sum(
112
+ 1 for r in new_row_indices if r in ground_truth_rows and r not in found_issue_rows
113
+ )
114
+ bonus = undiscovered * R_EXPLORE
115
+ if total_rows > 0:
116
+ all_inspected = inspected_rows | new_row_indices
117
+ coverage_ratio = len(all_inspected) / total_rows
118
+ bonus += len(new_row_indices) * R_EXPLORE * 0.5 * (1.0 - coverage_ratio)
119
+ return bonus
120
+
121
+ def diagnose_reward(self, type_match: bool) -> float:
122
+ """Reward for correct diagnosis."""
123
+ return R_DIAGNOSE + (R_TYPE_BONUS if type_match else 0.0)
124
+
125
+ def fix_reward(self, exact: bool, has_justification: bool) -> float:
126
+ """Reward for correct fix."""
127
+ reward = R_FIX if exact else R_FIX_PARTIAL
128
+ return reward + (R_JUSTIFY_BONUS if has_justification else 0.0)
dataforge/env/server.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server for the DataForge RL environment.
2
+
3
+ Provides OpenEnv-compatible HTTP endpoints:
4
+ POST /reset — Start a new episode
5
+ POST /step — Execute an action
6
+ GET /state — Return current state snapshot
7
+ POST /close — No-op shutdown
8
+ GET /health — Liveness check
9
+ GET /metadata — Environment metadata
10
+ GET /schema — Action/observation JSON schemas
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ import os
17
+ from threading import RLock
18
+ from typing import Any
19
+
20
+ from fastapi import FastAPI, HTTPException, Request
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from pydantic import TypeAdapter
23
+
24
+ from dataforge.agent.tool_actions import Action
25
+ from dataforge.env.environment import DataForgeEnv, EnvState
26
+ from dataforge.env.observation import DataForgeObservation
27
+ from dataforge.http.problem import problem_exception_handler
28
+ from dataforge.observability import configure_fastapi_observability
29
+
30
+ logger = logging.getLogger("dataforge.env.server")
31
+
32
+
33
+ def _build_cors_origins() -> list[str]:
34
+ """Build the explicit OpenEnv CORS allowlist from the environment."""
35
+ raw_origins = os.environ.get("DATAFORGE_OPENENV_ORIGINS", "")
36
+ return [origin.strip() for origin in raw_origins.split(",") if origin.strip()]
37
+
38
+
39
+ def _build_cors_origin_regex() -> str | None:
40
+ """Allow local browser development only when explicitly enabled."""
41
+ if os.environ.get("DATAFORGE_OPENENV_DEV") != "1":
42
+ return None
43
+ return r"^http://(?:localhost|127(?:\.\d{1,3}){3})(?::\d+)?$"
44
+
45
+
46
+ app = FastAPI(
47
+ title="DataForge Environment",
48
+ description="OpenEnv-compatible RL environment for data-quality repair.",
49
+ version="0.1.0",
50
+ )
51
+
52
+ app.add_middleware(
53
+ CORSMiddleware,
54
+ allow_origins=_build_cors_origins(),
55
+ allow_origin_regex=_build_cors_origin_regex(),
56
+ allow_credentials=False,
57
+ allow_methods=["GET", "POST", "OPTIONS"],
58
+ allow_headers=["*"],
59
+ )
60
+ app.add_exception_handler(HTTPException, problem_exception_handler)
61
+ configure_fastapi_observability(app, service_name="dataforge-openenv")
62
+
63
+ _registry_lock = RLock()
64
+ _default_env = DataForgeEnv()
65
+ _sessions: dict[str, DataForgeEnv] = {}
66
+
67
+
68
+ def _get_env(episode_id: str | None) -> DataForgeEnv:
69
+ """Resolve an environment by episode id, preserving legacy no-id behavior."""
70
+ if not episode_id:
71
+ return _default_env
72
+ with _registry_lock:
73
+ try:
74
+ return _sessions[episode_id]
75
+ except KeyError as exc:
76
+ raise HTTPException(
77
+ status_code=404,
78
+ detail={"error": "episode_not_found", "episode_id": episode_id},
79
+ ) from exc
80
+
81
+
82
+ def _remember_env(env: DataForgeEnv, episode_id: str) -> None:
83
+ """Register a session and update the legacy default environment."""
84
+ global _default_env
85
+ with _registry_lock:
86
+ _sessions[episode_id] = env
87
+ _default_env = env
88
+
89
+
90
+ @app.post("/reset")
91
+ async def reset(seed: int | None = None) -> dict[str, Any]:
92
+ """Reset the environment for a new episode."""
93
+ env = DataForgeEnv()
94
+ result = env.reset(seed=seed)
95
+ episode_id = str(result.info["episode_id"])
96
+ _remember_env(env, episode_id)
97
+ return result.model_dump(mode="json")
98
+
99
+
100
+ @app.post("/step")
101
+ async def step(action: dict[str, Any]) -> dict[str, Any]:
102
+ """Execute one agent action."""
103
+ action_payload = dict(action)
104
+ raw_episode_id = action_payload.pop("episode_id", None)
105
+ episode_id = str(raw_episode_id) if raw_episode_id else None
106
+ result = _get_env(episode_id).step(action_payload)
107
+ return result.model_dump(mode="json")
108
+
109
+
110
+ @app.get("/state")
111
+ async def state(episode_id: str | None = None) -> dict[str, Any]:
112
+ """Return current environment state snapshot."""
113
+ result = _get_env(episode_id).state()
114
+ return result.model_dump(mode="json")
115
+
116
+
117
+ @app.post("/close")
118
+ async def close(request: Request, episode_id: str | None = None) -> dict[str, Any]:
119
+ """No-op close endpoint for OpenEnv compatibility."""
120
+ body_episode_id: str | None = None
121
+ if episode_id is None:
122
+ try:
123
+ payload = await request.json()
124
+ except Exception:
125
+ payload = None
126
+ if isinstance(payload, dict) and payload.get("episode_id"):
127
+ body_episode_id = str(payload["episode_id"])
128
+
129
+ target_episode_id = episode_id or body_episode_id
130
+ env = _get_env(target_episode_id)
131
+ env.close()
132
+ if target_episode_id:
133
+ with _registry_lock:
134
+ _sessions.pop(target_episode_id, None)
135
+ return {"status": "closed", "episode_id": target_episode_id}
136
+
137
+
138
+ @app.get("/health")
139
+ async def health() -> dict[str, Any]:
140
+ """Liveness check."""
141
+ return {"status": "healthy", "environment": "dataforge-env"}
142
+
143
+
144
+ @app.get("/metadata")
145
+ async def metadata() -> dict[str, Any]:
146
+ """Environment metadata for OpenEnv discovery."""
147
+ return {
148
+ "name": "dataforge-env",
149
+ "version": "0.1.0",
150
+ "description": (
151
+ "DataForge RL Environment — agents learn to detect, diagnose, "
152
+ "and repair data-quality issues in tabular datasets."
153
+ ),
154
+ "action_types": [
155
+ "INSPECT_ROWS",
156
+ "SQL_QUERY",
157
+ "STAT_TEST",
158
+ "PATTERN_MATCH",
159
+ "HYPOTHESIS",
160
+ "ROOT_CAUSE",
161
+ "DIAGNOSE",
162
+ "FIX",
163
+ ],
164
+ }
165
+
166
+
167
+ @app.get("/schema")
168
+ async def schema() -> dict[str, Any]:
169
+ """Return JSON schemas for action and observation models."""
170
+ action_adapter: TypeAdapter[Action] = TypeAdapter(Action)
171
+ return {
172
+ "action": action_adapter.json_schema(),
173
+ "observation": DataForgeObservation.model_json_schema(),
174
+ "state": EnvState.model_json_schema(),
175
+ }
dataforge/evaluation_contract.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public evaluation evidence models for DataForge repair releases."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ from typing import Any, Literal
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+ InferabilityLabel = Literal[
12
+ "deterministic_normalization",
13
+ "context_derivable",
14
+ "external_reference_required",
15
+ "not_inferable_from_prompt",
16
+ ]
17
+ PROMOTION_SLICE: InferabilityLabel = "deterministic_normalization"
18
+ ABSTENTION_SLICES = frozenset({"external_reference_required", "not_inferable_from_prompt"})
19
+ AUXILIARY_SLICES = frozenset({"context_derivable"})
20
+ PromotionStatus = Literal[
21
+ "diagnostic_only",
22
+ "diagnostic_promoted",
23
+ "quality_improved_verified",
24
+ "public_quality_milestone",
25
+ "rejected",
26
+ ]
27
+
28
+
29
+ class EvaluationTaskV2(BaseModel):
30
+ """One auditable, source-stable model grading task.
31
+
32
+ Ground truth is retained for local grading but excluded from normal JSON
33
+ serialization so prompts and public reports cannot accidentally leak labels.
34
+ """
35
+
36
+ schema_version: Literal["evaluation_task_v2"] = "evaluation_task_v2"
37
+ task_id: str = Field(min_length=1)
38
+ prompt_hash: str = Field(min_length=64, max_length=64)
39
+ dataset_sha: str = Field(min_length=1)
40
+ split_id: str = Field(min_length=1)
41
+ inferability: InferabilityLabel
42
+ prompt: dict[str, Any]
43
+ allowed_columns: list[str] = Field(min_length=1)
44
+ valid_rows: list[int] = Field(min_length=1)
45
+ provenance: dict[str, Any]
46
+ hidden_ground_truth: list[dict[str, Any]] = Field(default_factory=list, exclude=True)
47
+
48
+ model_config = {"frozen": True}
49
+
50
+
51
+ class ReleaseEvidenceV2(BaseModel):
52
+ """Serializable release-gate evidence for model and benchmark promotion."""
53
+
54
+ schema_version: Literal["release_evidence_v2"] = "release_evidence_v2"
55
+ model_repo: str = Field(min_length=1)
56
+ model_sha: str = Field(min_length=1)
57
+ dataset_repo: str = Field(min_length=1)
58
+ dataset_sha: str = Field(min_length=1)
59
+ strict_macro_f1: float = Field(ge=0.0, le=1.0)
60
+ canonicalized_macro_f1: float = Field(ge=0.0, le=1.0)
61
+ parse_success_rate: float = Field(ge=0.0, le=1.0)
62
+ schema_case_error_count: int = Field(ge=0)
63
+ promotion_slice: InferabilityLabel = PROMOTION_SLICE
64
+ slice_scores: dict[InferabilityLabel, dict[str, float | int]] = Field(default_factory=dict)
65
+ inferability_slice_scores: dict[InferabilityLabel, float] = Field(default_factory=dict)
66
+ package_versions: dict[str, str] = Field(default_factory=dict)
67
+ promotion_status: PromotionStatus
68
+ gate_failures: list[str] = Field(default_factory=list)
69
+
70
+ model_config = {"frozen": True}
71
+
72
+
73
+ def prompt_sha256(prompt: dict[str, Any]) -> str:
74
+ """Hash a prompt payload with stable JSON serialization."""
75
+ encoded = json.dumps(prompt, sort_keys=True, separators=(",", ":")).encode("utf-8")
76
+ return hashlib.sha256(encoded).hexdigest()
dataforge/fixtures/hospital_10rows.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ provider_number,hospital_name,city,state,zip_code,phone_number,rating,mortality_rate,readmission_rate,er_wait_time
2
+ PRV001,General Hospital,Springfield,IL,62701,2175550101,4.2,0.023,0.145,28
3
+ PRV002,St. Mary Medical Center,Chicago,IL,60601,3125550202,3.8,0.031,0.162,35
4
+ PRV001,Springfield Medical,Springfield,IL,62701,2175550303,4.5,0.019,0.138,22
5
+ PRV003,Mercy Hospital,Peoria,IL,61602,3095550404,3.5,0.028,0.158,31
6
+ PRV004,Northwestern Memorial,Chicago,IL,60611,not available,4.1,0.025,0.149,26
7
+ PRV005,Rush University MC,Chicago,IL,60612,3125550606,45.0,0.022,0.141,29
8
+ PRV006,Advocate Christ,Oak Lawn,IL,60453,7085550707,3.9,0.027,0.155,33
9
+ PRV007,Loyola University MC,Maywood,IL,60153,7085550808,4.3,0.020,0.142,25
10
+ PRV008,Presence St. Joseph,Joliet,IL,60435,8155550909,4.0,0.026,0.151,30
11
+ PRV009,Edward Hospital,Naperville,IL,60540,6305551010,3.7,0.029,0.160,34
dataforge/fixtures/hospital_schema.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hospital dataset schema for DataForge profile command.
2
+
3
+ columns:
4
+ provider_number: str
5
+ hospital_name: str
6
+ city: str
7
+ state: str
8
+ zip_code: str
9
+ phone_number: str
10
+ rating: float
11
+ mortality_rate: float
12
+ readmission_rate: float
13
+ er_wait_time: int
14
+
15
+ functional_dependencies:
16
+ - determinant: [provider_number]
17
+ dependent: hospital_name
dataforge/http/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """HTTP helpers shared by DataForge backend surfaces."""
dataforge/http/problem.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RFC 9457 problem details helpers for FastAPI surfaces."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Mapping
6
+ from typing import Any
7
+
8
+ from fastapi import HTTPException, Request
9
+ from fastapi.responses import JSONResponse
10
+ from pydantic import BaseModel, ConfigDict, Field
11
+
12
+
13
+ class ProblemDetail(BaseModel):
14
+ """RFC 9457 problem detail response with extension members."""
15
+
16
+ type: str = Field(default="about:blank")
17
+ title: str
18
+ status: int
19
+ detail: str
20
+ instance: str | None = None
21
+
22
+ model_config = ConfigDict(strict=True, extra="allow")
23
+
24
+
25
+ def problem_body(
26
+ *,
27
+ status: int,
28
+ title: str,
29
+ detail: str,
30
+ type_: str = "about:blank",
31
+ instance: str | None = None,
32
+ **extensions: Any,
33
+ ) -> dict[str, Any]:
34
+ """Build a problem details JSON object."""
35
+ body = ProblemDetail(
36
+ type=type_,
37
+ title=title,
38
+ status=status,
39
+ detail=detail,
40
+ instance=instance,
41
+ **extensions,
42
+ )
43
+ return body.model_dump(mode="json", exclude_none=True)
44
+
45
+
46
+ def problem_response(
47
+ *,
48
+ status: int,
49
+ title: str,
50
+ detail: str,
51
+ type_: str = "about:blank",
52
+ instance: str | None = None,
53
+ headers: Mapping[str, str] | None = None,
54
+ **extensions: Any,
55
+ ) -> JSONResponse:
56
+ """Return an RFC 9457 JSON response."""
57
+ return JSONResponse(
58
+ status_code=status,
59
+ content=problem_body(
60
+ status=status,
61
+ title=title,
62
+ detail=detail,
63
+ type_=type_,
64
+ instance=instance,
65
+ **extensions,
66
+ ),
67
+ headers=headers,
68
+ media_type="application/problem+json",
69
+ )
70
+
71
+
72
+ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
73
+ """Normalize FastAPI HTTPException values into problem details."""
74
+ raw_detail = exc.detail
75
+ extensions: dict[str, Any] = {}
76
+ if isinstance(raw_detail, dict):
77
+ error_code = str(raw_detail.get("error", "http_error"))
78
+ message = str(raw_detail.get("message") or raw_detail.get("detail") or error_code)
79
+ extensions.update(raw_detail)
80
+ else:
81
+ error_code = "http_error"
82
+ message = str(raw_detail)
83
+
84
+ return problem_response(
85
+ status=exc.status_code,
86
+ type_=f"https://dataforge.local/problems/{error_code}",
87
+ title=error_code.replace("_", " ").title(),
88
+ detail=message,
89
+ instance=str(request.url.path),
90
+ headers=exc.headers,
91
+ **extensions,
92
+ )
93
+
94
+
95
+ async def problem_exception_handler(request: Request, exc: Exception) -> JSONResponse:
96
+ """Adapter with the broad exception signature Starlette expects."""
97
+ if isinstance(exc, HTTPException):
98
+ return await http_exception_handler(request, exc)
99
+ raise exc
dataforge/integrations/dbt.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """The dbt integration lives in the separate ``dataforge15-dbt`` package."""
dataforge/observability.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optional OpenTelemetry hooks for DataForge backend surfaces."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from collections.abc import Iterator
7
+ from contextlib import contextmanager, nullcontext
8
+ from importlib import import_module
9
+ from typing import Any
10
+
11
+ _SENSITIVE_ATTR_FRAGMENTS = ("authorization", "cookie", "token", "key", "secret", "password")
12
+
13
+
14
+ def _otel_enabled() -> bool:
15
+ """Return whether optional OpenTelemetry instrumentation is enabled."""
16
+ return os.environ.get("DATAFORGE_OTEL_ENABLED", "").strip().lower() in {
17
+ "1",
18
+ "true",
19
+ "yes",
20
+ "on",
21
+ }
22
+
23
+
24
+ def _safe_attrs(attributes: dict[str, Any]) -> dict[str, str | int | float | bool]:
25
+ """Keep only scalar, non-sensitive telemetry attributes."""
26
+ safe: dict[str, str | int | float | bool] = {}
27
+ for key, value in attributes.items():
28
+ lowered = key.lower()
29
+ if any(fragment in lowered for fragment in _SENSITIVE_ATTR_FRAGMENTS):
30
+ continue
31
+ if lowered in {"row_values", "rows", "payload", "source_bytes", "csv"}:
32
+ continue
33
+ if isinstance(value, str | int | float | bool):
34
+ safe[key] = value
35
+ return safe
36
+
37
+
38
+ def configure_fastapi_observability(app: Any, *, service_name: str) -> bool:
39
+ """Instrument a FastAPI app when OpenTelemetry is explicitly enabled."""
40
+ if not _otel_enabled():
41
+ return False
42
+ try:
43
+ fastapi_instrumentation = import_module("opentelemetry.instrumentation.fastapi")
44
+ trace_module = import_module("opentelemetry.trace")
45
+ except ImportError:
46
+ return False
47
+
48
+ app.state.dataforge_service_name = service_name
49
+ fastapi_instrumentation.FastAPIInstrumentor.instrument_app(
50
+ app,
51
+ tracer_provider=trace_module.get_tracer_provider(),
52
+ excluded_urls="/api/docs,/docs,/redoc,/openapi.json",
53
+ )
54
+ return True
55
+
56
+
57
+ @contextmanager
58
+ def repair_stage_span(stage: str, **attributes: Any) -> Iterator[None]:
59
+ """Create a repair-stage span when OpenTelemetry is available."""
60
+ if not _otel_enabled():
61
+ with nullcontext():
62
+ yield
63
+ return
64
+
65
+ try:
66
+ trace_module = import_module("opentelemetry.trace")
67
+ except ImportError:
68
+ with nullcontext():
69
+ yield
70
+ return
71
+
72
+ tracer = trace_module.get_tracer("dataforge.repair")
73
+ with tracer.start_as_current_span(stage) as span:
74
+ for key, value in _safe_attrs(attributes).items():
75
+ span.set_attribute(key, value)
76
+ yield
dataforge/py.typed ADDED
@@ -0,0 +1 @@
 
 
1
+
dataforge/release/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Release verification helpers for DataForge."""
2
+