diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..b8bd5438db31207e38d2b337007e7024aaf14c5c --- /dev/null +++ b/.env.example @@ -0,0 +1,16 @@ +# Copy to .env and fill in. .env is gitignored; .env.example is committed. +# Loaded automatically by zsgdp.config.load_env_file() when CLI / app starts. + +# Hugging Face Hub access token. Required for gated models like jina-v3 +# (the embedding retriever) and any private model id used in gpu.models. +# Read transparently by transformers / sentence-transformers when set. +HF_TOKEN= + +# Logging — see zsgdp/logging_config.py. +# ZSGDP_LOG_LEVEL=INFO +# ZSGDP_LOG_JSON=1 + +# Pipeline overrides. +# ZSGDP_CONFIG_PATH=configs/docling.yaml +# ZSGDP_MAX_UPLOAD_BYTES=52428800 +# ZSGDP_MAX_PAGE_COUNT=200 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..166ba27e461850bd2f20df384480079aff55a57f --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +__pycache__/ +*.py[cod] +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +.venv/ +venv/ +out/ +parsed/ +benchmarks/results/ + +# Secrets — never commit. Loaded by zsgdp.config.load_env_file() at runtime. +.env +.env.* +!.env.example diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..376bcf8bde26a138cf6302cc25cb584d4757b678 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,44 @@ +# Pre-commit and pre-push hooks for zeroshotGPU. +# +# Install once with: +# python -m pip install pre-commit +# pre-commit install --install-hooks --hook-type pre-commit --hook-type pre-push +# +# pre-commit runs only fast static checks on every commit so the developer +# loop stays tight. The slow `preflight` runs at pre-push time so it gates +# what reaches the remote without slowing down individual commits. + +default_language_version: + python: python3.11 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + stages: [pre-commit] + - id: end-of-file-fixer + stages: [pre-commit] + - id: check-yaml + stages: [pre-commit] + # The simple YAML in configs/*.yaml uses a tiny subset; check-yaml + # is fine. `app_file` etc. in README.md aren't real YAML headers + # — they're HF Spaces front-matter and excluded from this hook. + exclude: ^README\.md$ + - id: check-json + stages: [pre-commit] + - id: check-added-large-files + stages: [pre-commit] + args: ["--maxkb=2048"] + - id: check-merge-conflict + stages: [pre-commit] + + - repo: local + hooks: + - id: zsgdp-preflight + name: zsgdp preflight (unit + regression + space-check + parsers) + entry: python -m zsgdp.cli preflight --root . + language: system + pass_filenames: false + stages: [pre-push] + always_run: true diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..651fe1abdb9dc1eed8d3335b5db469fd02394ec8 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,274 @@ +# Changelog + +All notable changes to zeroshotGPU. Format follows +[Keep a Changelog](https://keepachangelog.com/en/1.1.0/); versions follow +[Semantic Versioning](https://semver.org/spec/v2.0.0.html) but the project is +pre-1.0 so minor bumps may include breaking changes. + +## [Unreleased] + +### Documentation — README restructured + +- Reorganised into Install → Quick start → Opt-ins → Outputs → + Architecture map → Production benchmark numbers → Deployment → + Contributing. +- New "Production benchmark numbers" placeholder table with §29 + success criteria recalled inline; columns are + `Metric / Dataset / Value / Date / Run` so the operator pastes real + numbers in after running `make space-smoke` and `make benchmark` + on the Space. +- Optional-extras table (`embedding`, `gpu_repair`, `spaces`) + documents what each extra adds and the config flag that requires it. +- Architecture quick-map turned into a table; one row per top-level + module with its responsibility. +- Deployment section is now a numbered checklist that ends with + "update the production-benchmark table." + +### Added — Space smoke validation runner + +- `scripts/run_space_smoke.py` automates the five smokes documented in + `docs/space_smoke.md`. One command runs whichever smokes have their + deps installed; missing deps surface as `skip` results with explicit + `pip install` hints, not crashes. +- Five smokes: `lexical` (model-free benchmark), `ablation` (per-parser + runner), `embedding` (sentence-transformers + jina-v3 lazy-load + probe), `gpu_repair` (dry-run plan + repair-loop iteration check — + *does not* download multi-GB Qwen weights, defers live invocation + to `run-gpu-tasks --execute`), `marker` (binary detection + + registry availability). +- `--strict` mode treats skipped smokes as failures; `--output PATH` + emits a structured JSON report with per-smoke `detail`, elapsed + seconds, status (`pass`/`fail`/`skip`/`error`), and aggregate + summary counts. +- 14 new tests covering registry membership, report aggregation, + text formatting per status, strict-mode skip-as-failure, end-to-end + smoke execution for the three model-free smokes, and skip-path + structure for the model-dependent ones. + +### Added — per-artifact downloads in the Space UI + +- New "Artifacts" tab in `app.py` exposes each top-level artifact + (`parsed_document.json`, `document.md`, `chunks.jsonl`, + `quality_report.json`, etc. — 16 candidate files) as an individual + download via `gr.Files`. The bundled zip stays as it was for + archival, and nested asset dirs (`assets/pages/*.png`, + `assets/tables/*.png`) are intentionally excluded from the + per-artifact list — they can be large and the zip already covers + them. +- The artifact list is built from `_INDIVIDUAL_ARTIFACT_NAMES` in + declaration order so the UI listing is stable across runs. Missing + files are silently skipped (different parses emit different subsets; + e.g. `conflict_report.json` only when multiple parsers ran). +- All return paths in `parse_uploaded_document` now go through a + single `_empty_outputs(...)` helper so the tuple width can't drift + between success and the four error paths. New drift-guard test + asserts `len(outputs) == 11` for every error path. +- Summary JSON now includes `individual_artifact_count`. + +### Added — CLI help with examples + +- Each non-trivial CLI subcommand (`parse`, `parse-folder`, `space-check`, + `run-gpu-tasks`, `benchmark`, `benchmark-ablate`, `preflight`, + `combine-benchmarks`, `export-chunks`, `validate-artifacts`, plus the + top-level help) now ships with an `Examples:` block in its `--help` + output. Multi-line shell snippets render via + `argparse.RawDescriptionHelpFormatter` + a textwrap-dedent helper so + the source-side indentation doesn't leak into the rendered output. +- `zsgdp run-gpu-tasks --help` now explicitly contrasts the dry-run + default against `--execute`, matching the safety contract of + `repair.execute_gpu_escalations` in config. +- 9 new tests guarding: epilog dedent helper, blank-line preservation + in epilogs, top-level help lists examples, and per-subcommand + examples cover their distinguishing flags (e.g. `benchmark` shows + all three dataset modes; `combine-benchmarks` shows label pairing). + +### Added — contributor onboarding + +- `CONTRIBUTING.md` documenting setup, hooks, test layout, fixture + format, parser/metric/schema-bump procedure, logging conventions, + PR checklist, and an architecture quick-map. +- `.pre-commit-config.yaml` with two stages: + - **pre-commit**: trailing whitespace, end-of-file fixer, JSON/YAML + syntax, large-file guard (2 MB cap), merge-conflict markers. + - **pre-push**: runs `python -m zsgdp.cli preflight` so failing + preflight blocks the push. External hook repo is pinned to a + specific tag (no `master`/`HEAD` references). +- `tests/test_repo_hygiene.py` (6 tests) — guards `.env` is in + `.gitignore`, `.env.example` is committed and contains no + real-shape secrets, pre-commit config has the preflight hook on + the pre-push stage with a pinned external repo, `CONTRIBUTING.md` + references the preflight workflow and Space smoke checklist, + `CHANGELOG.md` has an `[Unreleased]` section. + +### Added — performance baselines + +- Regression fixture format gains an optional `performance` block: + `repeats`, `max_elapsed_seconds`, `min_pages_per_second`, + `always_enforce`. The runner parses each fixture N times and compares + the median against the floor — the cold-import outlier on the first + run is stripped automatically. +- Default opt-in via `ZSGDP_REGRESSION_PERF=1`; per-fixture override + via `always_enforce: true`. Floors are intended as + catastrophic-regression guards, not tight perf bars. +- Seed fixture `markdown_basic` ships with a 2.0s / 0.5pps floor + (~80x slack against measured ~6ms median) so it exercises the path + without flaking on slow CI. +- 5 new unit tests for the perf evaluator: max-elapsed and + min-pps trip correctly, median strips cold outliers, env-var gating + honours `always_enforce`. + +### Added — preflight + secrets + +- Preflight runner: `zsgdp preflight` CLI subcommand and `make preflight` + target. Chains `unittest discover`, regression fixtures, `space-check`, + and `parsers` registry sanity. `--benchmark` adds an end-to-end smoke + against the regression fixtures. Each step's output is suppressed on + success and surfaced on failure; one-line summary always printed. +- `Makefile` with targets `test`, `regression`, `space-check`, `parsers`, + `preflight`, `preflight-full`, `benchmark`, `clean`. +- `.env` loading via `zsgdp.config.load_env_file()`. Read at CLI start + and `app.py` import; pre-set environment variables always win. Never + overrides Space-side secrets. `.env.example` shipped as the template. +- `.env`/`.env.*` added to `.gitignore` (`.env.example` whitelisted). +- `zsgdp.config.hf_token()` resolves `HF_TOKEN`, + `HUGGING_FACE_HUB_TOKEN`, `HUGGINGFACE_TOKEN` in priority order. + +### Added — structured logging + +- `zsgdp.logging_config` with idempotent `configure_logging()`. Default + level WARNING; opt in via `ZSGDP_LOG_LEVEL`. Optional one-line JSON + records via `ZSGDP_LOG_JSON=1`; structured `extra={...}` fields + promoted to top-level keys for HF Spaces logs / Loki / Datadog. +- Wired into pipeline (`parse_start`, `parser_candidate`, + `parser_failed`, `parse_end`), repair controller (`repair_iteration`), + GPU worker (`gpu_task_executed`, `gpu_task_blocked`), CLI, and + `app.py`. App auto-enables JSON mode when `SPACE_ID` is set. + +### Added — deployment-readiness pass + +- Pinned upper bounds on all `requirements.txt` and `pyproject.toml` + dependencies. Added explicit `embedding` and `gpu_repair` extras so the + optional sentence-transformers / transformers stacks can be installed + without dragging the whole spaces extra in. +- Abuse / cost guards in the Gradio Space entrypoint (`app.py`): + `MAX_UPLOAD_BYTES` (50 MB default) and `MAX_PAGE_COUNT` (200 default), + both overridable via `ZSGDP_MAX_*` env vars. Oversized uploads are + rejected with a clear UI error before parsing starts. +- `SCHEMA_VERSION` constant and `ParsedDocument.schema_version` field. + Surfaced into the artifact manifest as + `parsed_document_schema_version` alongside the existing manifest + `schema_version`. Validation report echoes both so consumers can gate. +- Regression fixture format under `tests/regression/`: a YAML-style + `*.expected.json` tolerance spec paired with an input document. Runner + auto-discovers, asserts on tolerances (counts, score, markdown + contains/excludes, repair/disagreement rate ranges). One seed fixture + shipped (`markdown_basic`). + +### Added — eval surface + +- Per-parser GT-comparison metrics within a single merged run + (`zsgdp/benchmarks/per_parser_metrics.py`). Reads pre-merge candidate + snapshots from `parsed.provenance.candidates` and computes layout F1 / + table structure / formula CER per parser against the same GT. + Surfaced as `per_parser_metrics.csv` and per-doc field + `per_parser_metrics`. +- Per-parser cross-doc leaderboard rollup + (`per_parser_gt_leaderboard.csv`) with truth-aware filtering: a metric + contributes to a parser's mean only when that parser was actually + evaluated against truths for that metric on that document. +- Cross-dataset comparison (`zsgdp/benchmarks/cross_dataset.py`) with + `combine-benchmarks` CLI subcommand. Combines multiple + `results.json` summaries into `dataset_summary.csv` and a + parser-vs-dataset matrix. Missing metrics surface as `None` rather + than 0.0 so callers can distinguish absent from true-zero. +- Embedding-based retriever (`zsgdp/benchmarks/embedding_retriever.py`) + satisfying the `Retriever` protocol. Defaults to lexical (model-free, + CI-safe); opt in via `benchmarks.retriever.backend=embedding` in + config. Lazy-loads `sentence-transformers` on first use; falls back + cleanly when unavailable. +- Layout F1 against ground-truth bbox annotations + (`zsgdp/verify/layout_f1.py`). Class-aware and class-agnostic scores + side-by-side, per-category breakdown. DocLayNet COCO and OmniDocBench + JSON adapters in `zsgdp/benchmarks/ground_truth.py`. +- Table structure similarity (`zsgdp/verify/table_structure.py`): + shape similarity × multiset cell-content F1, greedy bipartite + matching. +- Formula extraction CER (`zsgdp/verify/formula_extraction.py`): + Levenshtein-based, normalized for whitespace and `$`/`$$` delimiters. +- Retrieval-readiness metrics (`zsgdp/verify/retrieval.py`): recall@k, + citation accuracy@k, mean reciprocal rank. Synthetic QA generator + (`zsgdp/benchmarks/retrieval.py`) using distinctive sentences. +- Parser-disagreement rate + (`zsgdp/verify/parser_disagreement.py`): conflict count over parser + pair count from the merger's existing conflict report. +- Repair success / regression rates + (`zsgdp/verify/repair_success.py`): pre/post issue identity diff; + iteration history, score delta, action counts. +- Parser contribution counts: which parser's elements survived the + merge, surfaced as per-doc and aggregate fractions. +- Parser ablation runner (`zsgdp/benchmarks/ablation_runner.py`) with + `benchmark-ablate` CLI subcommand. Runs the benchmark once per parser + in isolation plus a merged arm, emits a comparison CSV. +- Three dataset loaders (`zsgdp/benchmarks/datasets.py`): + `custom_folder`, `omnidocbench`, `doclaynet`. `DatasetDocument` + dataclass; registry pattern for downstream extension. + +### Added — pipeline + +- Iterative repair loop in `pipeline.py`: bounded by + `repair.max_iterations`, terminates on quality-accepted OR + no-changes-this-pass. Per-iteration history under + `provenance.repair_iterations`. +- GPU repair escalation wired into `repair/controller.py`. Plans + same-schema GPU tasks for invalid tables, OCR/text coverage issues, + reading-order failures, and figure issues, then dispatches via + `GPUWorker`. Default safe (`repair.gpu_escalation=true, + repair.execute_gpu_escalations=false`); flip the second to invoke + the configured backend. +- Per-parser candidate snapshots persisted in + `parsed.provenance.candidates` so per-parser GT metrics can be + computed without re-parsing. +- Real Marker and Unstructured normalizers + (`zsgdp/normalize/normalize_marker.py` and + `normalize_unstructured.py`) wired through `parsers/external.py`. + +### Changed + +- `requirements.txt` no longer pins `torch`; the HF Spaces image + preinstalls a CUDA-matched build and pinning here would override it. +- `--gpu-workers` flag help text clarified — the value is recorded for + downstream task-execution accounting but document parsing uses + `--workers`. +- `--dataset` benchmark flag now selects the loader name + (default `custom_folder`); `custom`/`folder`/`default` accepted as + aliases. Previous behaviour was a freeform reporting label only. +- Embedding-retriever toy hashing test now uses + `hashlib.md5`-based stable hashing instead of `builtins.hash()`, + fixing per-process flakiness. + +### Documentation + +- `tests/regression/README.md` documents the fixture format. +- `configs/default.yaml` and `configs/docling.yaml` annotated to + explain the new `repair.execute_gpu_escalations` and the deliberate + Docling+PyMuPDF dual-enable for the disagreement metric. + +### Test count + +- 181 tests pass (was 4 at the start of the eval surface work). + +## [0.1.0] — initial MVP + +- Profiler, page router, parser registry (text, pymupdf, docling, plus + shell-out adapters for marker / mineru / olmocr / paddleocr / + unstructured). +- Canonical schema (`Element`, `TableObject`, `FigureObject`, `Chunk`, + `ParsedDocument`, `QualityReport`). +- Merger with conflict detection, quality verifier (coverage, reading + order, table validity, chunk readiness), deterministic repair + controller. +- Agentic chunker with fixed-token / recursive-structure / parent-child + / page-level / table / figure strategies; semantic / late / + vision-guided / proposition stubs. +- Artifact manifest with SHA-256 checksums, `validate-artifacts` CLI. +- Gradio Spaces entrypoint, `space-check` deployment readiness CLI. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..04c1b16584894963b4f68b5af2693c9e0315c1a3 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,235 @@ +# Contributing to zeroshotGPU + +Thanks for working on this. Three things to know up front: + +1. **Run `make preflight` before pushing.** It's the same suite that runs + in pre-push if you have the hooks installed (see below). A green + preflight is the local signal that the branch is ready for the + [Space smoke checklist](docs/space_smoke.md). +2. **Keep it dependency-light by default.** New runtime dependencies need + a corresponding entry in `pyproject.toml` extras and an explicit + gate (config flag, lazy import, or feature-detection fallback). The + `embedding` extra is the model: opt-in, lazy-imported on first use, + raises a clean `RuntimeError` when missing. +3. **Don't change schema shapes silently.** Bump + `zsgdp.schema.SCHEMA_VERSION` whenever the on-disk shape of + `parsed_document.json`, `chunks.jsonl`, etc. changes. See + [Schema versioning](#schema-versioning) below. + +--- + +## Setup + +```bash +git clone +cd "Document Parser" +python3.11 -m venv .venv && source .venv/bin/activate +python -m pip install -e ".[pdf,yaml,docling,dev]" +``` + +Optional extras: + +- `.[embedding]` — sentence-transformers + transformers for the embedding + retriever. Only needed when you set `benchmarks.retriever.backend=embedding`. +- `.[gpu_repair]` — transformers for live GPU repair. Only needed when you + set `repair.execute_gpu_escalations=true`. +- `.[spaces]` — mirrors the root `requirements.txt` so an editable install + matches a Space deploy. + +Set up `.env` for local secrets: + +```bash +cp .env.example .env +# Fill in HF_TOKEN if you need gated models. +``` + +`.env` is gitignored. CLI and `app.py` load it automatically; pre-set +environment variables always win, so a Space's secrets never get +overridden by a stray local file. + +--- + +## Pre-commit / pre-push hooks + +```bash +python -m pip install pre-commit +pre-commit install --install-hooks --hook-type pre-commit --hook-type pre-push +``` + +Two stages: + +- **pre-commit** — fast static checks: trailing whitespace, end-of-file + newline, JSON/YAML syntax, large-file guard, merge-conflict markers. + Runs on every `git commit`. +- **pre-push** — runs `python -m zsgdp.cli preflight`. Same as + `make preflight`. Failing this blocks the push. + +Skip on a specific commit with `git commit --no-verify` if you genuinely +need to (e.g. WIP). Skip the pre-push gate with `git push --no-verify`, +but only if you have a separately verified preflight run. + +--- + +## Running tests + +```bash +make test # full unittest discover +make regression # snapshot fixture suite +make preflight # everything except the benchmark smoke +make preflight-full # adds an end-to-end benchmark smoke +make benchmark # parses tests/regression/fixtures/ via the CLI +``` + +Or directly: + +```bash +python -m unittest discover +python -m unittest tests.regression.test_regression +python -m zsgdp.cli preflight --root . --benchmark +``` + +Performance regressions are gated behind `ZSGDP_REGRESSION_PERF=1`: + +```bash +ZSGDP_REGRESSION_PERF=1 python -m unittest tests.regression.test_regression +``` + +See [tests/regression/README.md](tests/regression/README.md) for the +fixture format including the `performance` block. + +--- + +## Adding a regression fixture + +1. Drop the input under `tests/regression/fixtures/.input.`. +2. Parse it once locally and inspect the output: + ```bash + python -m zsgdp.cli parse --input tests/regression/fixtures/.input. --output /tmp/sanity + ``` +3. Hand-write `tests/regression/fixtures/.expected.json` with the + tolerances you want to lock down. Prefer ranges over exact counts + where reasonable variance exists. +4. Optional: add a `performance` block with `max_elapsed_seconds` set to + ~50–100x your local median (catastrophic-regression guard, not a + tight bar). +5. Run `make regression` to confirm the fixture is picked up. + +--- + +## Adding a parser adapter + +1. Subclass `BaseParser` in `zsgdp/parsers/_parser.py` (or extend + `external.py` for shell-out adapters). +2. Set `name`, `supported_file_types`, implement `available()` and + `parse(path, profile, config, *, pages=None)`. +3. Register in `zsgdp/parsers/registry.py`. +4. If the parser produces Markdown, write a normalizer under + `zsgdp/normalize/normalize_.py` that returns a `ParseCandidate` + via `normalize_markdown_candidate(...)`. +5. Add a config block to `configs/default.yaml` with `enabled: false` + plus any CLI flags the adapter needs. +6. Add the dependency to `pyproject.toml` as an optional extra. Don't + pin it in the top-level `requirements.txt` unless it's free to + install on every Space build. + +--- + +## Adding a metric + +Pure metrics live under `zsgdp/verify/`: + +1. Define inputs as plain dicts/lists (not `ParsedDocument`-keyed) so + the same metric works on per-parser candidate snapshots, not just + the merged document. +2. Pin definitions in the module docstring — exact denominator, + handling of empty inputs, what each return key means. +3. Surface in `zsgdp/benchmarks/parser_quality.py`: + - Add per-document fields to the `doc_record`. + - Add aggregated means to the top-level `summary` dict. + - Add a per-document CSV writer if it has detail worth its own file. +4. Add tests for: perfect input, no-match input, partial overlap, + vacuous empty/empty case, and a benchmark-integration test that + asserts the metric appears in `summary["documents"][0]`. + +--- + +## Schema versioning + +`zsgdp.schema.SCHEMA_VERSION` lives in +[zsgdp/schema/document.py](zsgdp/schema/document.py). It's surfaced into +`artifact_manifest.json` as `parsed_document_schema_version` so a +consumer reading old output can gate. + +Bump rules: + +- **Additive change** (new optional field with a default) — bump the + patch (1.0 → 1.1). +- **Breaking change** (renamed/removed field, semantics changed) — bump + the major (1.0 → 2.0). Update the regression fixtures in the same + PR; downstream consumers will need a migration. +- **No change** — leave it alone. + +When you bump, add an entry to `CHANGELOG.md` under +"### Schema" with the version and what changed. + +--- + +## Logging + +Use `from zsgdp.logging_config import get_logger` then +`logger = get_logger(__name__)`. Call `.info`/`.warning`/`.error` with +structured `extra={...}` fields rather than f-string-formatted messages +where possible — the JSON formatter promotes `extra` keys to top-level +fields so the HF Spaces logs page is greppable. + +Default log level is WARNING (CLI summaries unaffected). Opt in with +`ZSGDP_LOG_LEVEL=INFO` and `ZSGDP_LOG_JSON=1` for Space-style output. + +--- + +## Pull request checklist + +Before opening a PR: + +- [ ] `make preflight` passes locally. +- [ ] If you added a metric, an adapter, or changed the schema, you + updated `CHANGELOG.md`. +- [ ] If you changed parser behavior, you ran `make regression` and any + fixture drift is intentional (and the snapshot was regenerated + explicitly). +- [ ] If your change touches GPU/model code paths, you flagged it for + Space-side smoke testing in the PR description (the + [smoke checklist](docs/space_smoke.md) covers what to run). +- [ ] You did **not** commit `.env` or any secret. The `.gitignore` + should catch this; if you suspect a leak, treat the token as + compromised and rotate it. + +--- + +## Architecture quick map + +- `zsgdp/profiling/` — page-level features and labels. +- `zsgdp/routing/` — deterministic page → expert mapping. +- `zsgdp/parsers/` — adapters; one canonical schema regardless of source. +- `zsgdp/normalize/` — convert each parser's output into the schema. +- `zsgdp/merge/` — align candidates, dedupe, detect conflicts. +- `zsgdp/verify/` — coverage, reading order, table/figure/formula/chunk + quality, GT-comparison metrics (layout F1, table structure, formula + CER, retrieval recall), parser disagreement and repair success rates. +- `zsgdp/repair/` — deterministic header/table fixes plus GPU + escalation that dispatches to `gpu/worker.py`. +- `zsgdp/chunking/` — agentic planner + structure-aware / parent-child / + table / figure / page chunk builders, with semantic / late / + vision-guided / proposition deterministic stubs. +- `zsgdp/gpu/` — task planning, batching, dry-run worker, transformers + and vLLM clients. +- `zsgdp/benchmarks/` — dataset loaders, metric runners, ablation, + cross-dataset comparison, retrieval (lexical + embedding). +- `zsgdp/cli.py` — single entry point exposing all of the above. +- `app.py` — Gradio Space front-end. + +The full spec lives in +[zero_shot_gpu_document_parser_project_spec.md](zero_shot_gpu_document_parser_project_spec.md). +The 2000-line read isn't required to contribute, but section §10 (schema) +and §17 (chunking ladder) are worth skimming if you're touching those +modules. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..5b013f8fdb924702fded3470c7dcd5a2d03876db --- /dev/null +++ b/Makefile @@ -0,0 +1,49 @@ +PYTHON ?= python3.11 + +.PHONY: help test regression space-check parsers preflight preflight-full benchmark space-smoke space-smoke-strict clean + +help: + @echo "Targets:" + @echo " test - run the full unittest discover suite" + @echo " regression - run the regression fixture snapshot suite" + @echo " space-check - run the HF Space readiness check" + @echo " parsers - print the parser registry" + @echo " preflight - run test + regression + space-check + parsers" + @echo " preflight-full - preflight + an end-to-end benchmark smoke" + @echo " benchmark - run zsgdp benchmark against tests/regression/fixtures" + @echo " space-smoke - run docs/space_smoke.md smokes (deps-permitting)" + @echo " space-smoke-strict - same, but treat skipped smokes as failures" + @echo " clean - remove __pycache__ and benchmark output" + +test: + $(PYTHON) -m unittest discover + +regression: + $(PYTHON) -m unittest tests.regression.test_regression -v + +space-check: + $(PYTHON) -m zsgdp.cli space-check --root . + +parsers: + $(PYTHON) -m zsgdp.cli parsers + +preflight: + $(PYTHON) -m zsgdp.cli preflight --root . + +preflight-full: + $(PYTHON) -m zsgdp.cli preflight --root . --benchmark + +benchmark: + $(PYTHON) -m zsgdp.cli benchmark \ + --input tests/regression/fixtures \ + --output out/preflight_benchmark + +space-smoke: + $(PYTHON) -m scripts.run_space_smoke --output out/space_smoke_report.json + +space-smoke-strict: + $(PYTHON) -m scripts.run_space_smoke --strict --output out/space_smoke_report.json + +clean: + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + rm -rf out/preflight_benchmark diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f274b060c09baa6d83aa42f37df7d3c60772ca30 --- /dev/null +++ b/README.md @@ -0,0 +1,287 @@ +--- +title: zeroshotGPU +sdk: gradio +app_file: app.py +python_version: 3.11 +suggested_hardware: l4x1 +short_description: Agentic zero-shot document parser with parser metrics and chunk artifacts. +--- + +# Zero-Shot GPU Document Parser + +A self-hosted parsing control plane that profiles documents, routes pages to +parser experts, normalizes outputs, verifies quality with GT-comparison +metrics, repairs weak regions through a bounded verify/repair loop (with +optional GPU escalation), and emits auditable parsed-document artifacts plus +strategy-aware chunks. Implements the project described in +[`zero_shot_gpu_document_parser_project_spec.md`](zero_shot_gpu_document_parser_project_spec.md). + +The codebase is intentionally dependency-light by default. Text and Markdown +work with the standard library; PyMuPDF, Docling, Marker, MinerU, olmOCR, +PaddleOCR, and Unstructured plug in via optional extras. Live GPU repair +(Qwen2.5-VL-3B) and the embedding retriever (jina-embeddings-v3) are gated +behind explicit config flags so a fresh clone never silently downloads +multi-gigabyte weights. + +--- + +## Install + +For the local MVP (text + PyMuPDF + Docling): + +```bash +python -m pip install -e ".[pdf,yaml,docling,dev]" +``` + +Optional extras: + +| Extra | Adds | Required for | +|---------------|--------------------------------------------------|-----------------------------------------------| +| `embedding` | `sentence-transformers`, `transformers` | `benchmarks.retriever.backend=embedding` | +| `gpu_repair` | `transformers` | `repair.execute_gpu_escalations=true` | +| `spaces` | mirrors `requirements.txt` for HF Spaces parity | running `app.py` locally as a Space simulant | + +External parser CLIs (Marker, MinerU, olmOCR, PaddleOCR) install separately; +configure each via `parsers..command`, `output_args`, and `extra_args` +in your YAML config. + +Secrets: + +```bash +cp .env.example .env +# Set HF_TOKEN if you'll use gated models (jina-embeddings-v3, private repos). +``` + +`.env` is gitignored. The CLI and `app.py` load it on startup; pre-set +environment variables (e.g. Space-side secrets) always win. + +--- + +## Quick start + +### Parse one document or a folder + +```bash +python -m zsgdp.cli parse --input ./docs/sample.md --output ./out/sample +python -m zsgdp.cli parse-folder --input ./docs --output ./parsed --workers 4 +python -m zsgdp.cli parse --input ./docs/report.pdf --output ./out/report --config configs/docling.yaml +``` + +Each parse writes a full artifact bundle. `parsed_document.json` is the +canonical record; `chunks.jsonl` is the retrieval-ready output; +`quality_report.json` carries every metric the verifier computed. + +### Run a benchmark + +```bash +# Custom corpus, no GT — runs every metric that doesn't need labels: +python -m zsgdp.cli benchmark --input ./docs --output ./bench + +# Labelled datasets — adds layout F1 / table structure / formula CER: +python -m zsgdp.cli benchmark --input ./omnidocbench --dataset omnidocbench --output ./bench/omni +python -m zsgdp.cli benchmark --input ./doclaynet --dataset doclaynet --output ./bench/doclay +``` + +### Compare parsers (ablation) + +```bash +python -m zsgdp.cli benchmark-ablate \ + --input ./docs --output ./bench/ablation \ + --parser docling --parser pymupdf --parser text +``` + +Runs the benchmark once per parser plus a merged arm; emits +`ablation_comparison.csv`. + +### Compare across datasets + +```bash +python -m zsgdp.cli combine-benchmarks \ + --input ./bench/omni --label omnidocbench \ + --input ./bench/doclay --label doclaynet \ + --output ./bench/cross +``` + +Emits `dataset_summary.csv` and `parser_matrix.csv` (parser × dataset). + +### Before pushing to a Space — preflight + +```bash +make preflight # unit + regression + space-check + parsers (~10s) +make preflight-full # ...plus an end-to-end benchmark smoke +``` + +A green preflight is the local signal that the branch is ready for the +Space. Pre-commit and pre-push hooks (see [CONTRIBUTING.md](CONTRIBUTING.md)) +make this automatic on every `git push`. + +### On the Space — smoke validation + +Once deployed, exercise the deferred GPU/model paths: + +```bash +make space-smoke # runs whichever of 5 smokes have their deps +python -m scripts.run_space_smoke --strict --output ./space_smoke.json +``` + +See [docs/space_smoke.md](docs/space_smoke.md) for the manual fallback +procedure (real PDF uploads, full Marker parses) and per-smoke +acceptance criteria. + +--- + +## Opt-ins + +### Embedding retriever + +Default retriever is lexical TF-IDF (zero deps). To use a real embedder: + +```yaml +# configs/myrun.yaml +benchmarks: + retriever: + backend: embedding + model_id: jinaai/jina-embeddings-v3 # or any sentence-transformers model + task: retrieval.passage +``` + +```bash +python -m pip install -e ".[embedding]" +python -m zsgdp.cli benchmark --input ./docs --output ./bench --config configs/myrun.yaml +``` + +The first call lazy-loads the model; subsequent calls reuse it in-process. +Set `HF_TOKEN` in `.env` for gated models. + +### Live GPU repair + +The repair controller plans GPU tasks for verification failures (invalid +tables, OCR coverage gaps, reading-order issues, missing figure captions). +By default these are dry-run only. To execute: + +```yaml +# configs/myrun.yaml +repair: + gpu_escalation: true + execute_gpu_escalations: true # invokes the configured backend +gpu: + backend: transformers # or "vllm" for OpenAI-compat + models: + table: + model_id: Qwen/Qwen2.5-VL-3B-Instruct +``` + +Each executed task writes its output back into the merged document with a +`gpu_repair_task_id` provenance field. + +--- + +## Outputs + +Every parse writes: + +- `parsed_document.json` — canonical record (carries `schema_version`). +- `document.md` — human-readable Markdown reconstruction. +- `elements.jsonl` / `tables.jsonl` / `figures.jsonl` / `chunks.jsonl` — JSONL streams. +- `chunking_plan.json` — strategy ladder + per-strategy metadata. +- `parser_metrics.json` — per-parser candidate-level stats. +- `quality_report.json` — every verifier metric (text coverage, reading order, table validity, parser disagreement, repair resolution/regression rates, GT-comparison metrics when applicable). +- `routing_report.json` — page → parser routing decisions. +- `profile.json` — document profiler output. +- `gpu_runtime.json` — detected GPU/device state at parse time. +- `gpu_tasks.jsonl` (when model-backed work is planned) and `gpu_task_report.json` (preflight validation). +- `conflict_report.json` (when multiple parsers ran). +- `artifact_manifest.json` with SHA-256 checksums and the parsed-document schema version. +- `assets/pages/*.png`, `assets/tables/*.png`, `assets/figures/*.png` — rendered PDF page and region crops. + +Benchmark runs additionally write: + +- `results.json` — full structured summary including aggregate means. +- `leaderboard.csv` and `per_parser_gt_leaderboard.csv` — parser leaderboards (without and with GT comparison). +- `per_parser_metrics.csv` — per-document, per-parser GT-comparison breakdown. +- `layout_runs.csv`, `table_structure_runs.csv`, `formula_runs.csv`, `retrieval_runs.csv`, `repair_runs.csv` — per-document detail per metric family. +- `parser_runs.csv`, `chunk_runs.csv`, `structure_runs.csv`, `chunk_quality.csv`, `throughput_runs.csv`, `ablations.json` — additional detail. + +`benchmark-ablate` adds `ablation_comparison.csv`. `combine-benchmarks` +adds `dataset_summary.csv`, `parser_matrix.csv`, and +`cross_dataset_comparison.json`. + +--- + +## Architecture map + +| Module | Responsibility | +|-------------------------|-------------------------------------------------------------------------| +| `zsgdp/profiling/` | Cheap per-page features (scanned-score, table density, columns, etc.) | +| `zsgdp/routing/` | Deterministic page → parser-expert decisions with budget | +| `zsgdp/parsers/` | Adapters; one canonical schema regardless of source | +| `zsgdp/normalize/` | Convert each parser's output into the schema | +| `zsgdp/merge/` | Align candidates, dedupe, detect conflicts | +| `zsgdp/verify/` | Coverage / reading order / table / figure / formula / chunk readiness, plus GT-comparison: layout F1, table structure, formula CER, retrieval recall, parser disagreement, repair success | +| `zsgdp/repair/` | Deterministic header/table fixes plus GPU escalation through `gpu/worker.py` | +| `zsgdp/chunking/` | Agentic planner + structure / parent-child / table / figure / page chunkers, with semantic / late / vision / proposition deterministic stubs | +| `zsgdp/gpu/` | Task planning, batching, dry-run worker, transformers + vLLM clients | +| `zsgdp/benchmarks/` | Dataset loaders, metric runners, ablation, cross-dataset, retrieval | +| `zsgdp/cli.py` | All entry points | +| `app.py` | Gradio Space UI | + +The full spec is in [`zero_shot_gpu_document_parser_project_spec.md`](zero_shot_gpu_document_parser_project_spec.md). §10 (schema) and §17 (chunking ladder) are the most useful sections to skim before touching those modules. + +--- + +## Production benchmark numbers + +Once the Space deploy is live and `make space-smoke` is green, run the +benchmark against your representative corpus and paste the headline +metrics here. Spec §29 success criteria for reference: + +- **MVP:** full agentic loop improves table QA by ≥20% over best single parser; agentic chunking improves citation accuracy by ≥10% over recursive baseline. +- **Production-style (HR / financial reports / etc.):** retrieval recall@5 ≥ 90%, citation accuracy ≥ 90%, table QA exactness ≥ 85%, manual review rate ≤ 10%, parser blocking failure rate ≤ 5%. + +| Metric | Dataset / Corpus | Value | Date | Run | +|---------------------------------|------------------|-------|------|-----| +| `mean_quality_score` | _todo_ | _todo_| _todo_ | _todo_ | +| `mean_layout_f1` | _todo_ | _todo_| _todo_ | _todo_ | +| `mean_table_structure_score` | _todo_ | _todo_| _todo_ | _todo_ | +| `mean_formula_cer` | _todo_ | _todo_| _todo_ | _todo_ | +| `mean_retrieval_recall_at_5` | _todo_ | _todo_| _todo_ | _todo_ | +| `mean_parser_disagreement_rate` | _todo_ | _todo_| _todo_ | _todo_ | +| `mean_repair_resolution_rate` | _todo_ | _todo_| _todo_ | _todo_ | +| `mean_pages_per_second` | _todo_ | _todo_| _todo_ | _todo_ | + +Source rows are individual `results.json` files under each Space-side +benchmark output; commit the directory or a redacted summary so the +numbers above are reproducible. + +--- + +## Deployment + +Targeted: Hugging Face Spaces, hardware `l4x1`, GPU/model target +`zeroshotGPU`. + +Pre-deploy gate: + +1. `make preflight` (local). +2. `make preflight-full` (local with end-to-end benchmark smoke). +3. Duplicate the Space, set `HF_TOKEN` and any other secrets in **Variables and secrets**. +4. Push. +5. `make space-smoke` from the Space's JupyterLab terminal. +6. Inspect [docs/space_smoke.md](docs/space_smoke.md) Smoke 3 (live GPU repair) manually if the runner-level wiring smoke passed but you want full model-invocation validation. +7. Run `python -m zsgdp.cli benchmark` against your representative corpus and update the table above. + +The Space defaults to `configs/docling.yaml` (Docling + PyMuPDF +co-enabled so the parser disagreement rate has signal). Override via +`ZSGDP_CONFIG_PATH` in Space variables for custom configs. + +--- + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for setup, hooks, test layout, +fixture format, parser/metric/schema-bump procedures, and the PR checklist. + +For changes touching the on-disk schema, bump `zsgdp.schema.SCHEMA_VERSION` +and add an entry under `### Schema` in [CHANGELOG.md](CHANGELOG.md). The +artifact manifest surfaces the version under +`parsed_document_schema_version` so downstream consumers can gate. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..36e17365beb0addc214110809d7fdfaa2e75a47c --- /dev/null +++ b/app.py @@ -0,0 +1,251 @@ +"""Hugging Face Spaces entrypoint for zeroshotGPU.""" + +from __future__ import annotations + +import os +import shutil +import tempfile +from pathlib import Path +from typing import Any + +try: + import gradio as gr +except ImportError as exc: # pragma: no cover - only used when launching the Space UI. + raise RuntimeError("Gradio is required for the Spaces UI. Install with `python -m pip install -r requirements.txt`.") from exc + +from zsgdp.artifacts import validate_artifact_manifest +from zsgdp.config import load_config, load_env_file +from zsgdp.gpu import collect_gpu_runtime_status +from zsgdp.logging_config import configure_logging, get_logger +from zsgdp.pipeline import parse_document +from zsgdp.profiling import profile_document + +# Load .env first so any keys it sets (HF_TOKEN, ZSGDP_LOG_LEVEL, etc.) are +# visible before we read environment defaults below. Pre-set Space variables +# always win — load_env_file does not override existing env entries. +load_env_file() + +# Default to JSON logs on the Space so the HF Spaces logs page is greppable. +# Override locally with `ZSGDP_LOG_JSON=0` for human-readable text output. +os.environ.setdefault("ZSGDP_LOG_LEVEL", "INFO") +os.environ.setdefault("ZSGDP_LOG_JSON", "1" if os.environ.get("SPACE_ID") else "0") +configure_logging() +_logger = get_logger(__name__) + +ROOT = Path(__file__).resolve().parent +DOCLING_CONFIG = ROOT / "configs" / "docling.yaml" + +# Abuse guards. Override at deployment time via env vars to relax for trusted +# Spaces or tighten further for public ones. +MAX_UPLOAD_BYTES = int(os.environ.get("ZSGDP_MAX_UPLOAD_BYTES", str(50 * 1024 * 1024))) # 50 MB +MAX_PAGE_COUNT = int(os.environ.get("ZSGDP_MAX_PAGE_COUNT", "200")) + + +class UploadRejected(Exception): + """Raised when an upload exceeds an abuse-guard limit.""" + + +def _validate_upload(path: Path) -> None: + """Reject oversized uploads or PDFs with too many pages before parsing. + + Cheap to compute (file stat + profiler page count) and avoids spending + GPU/CPU minutes on inputs the Space wasn't sized for. + """ + + if not path.exists(): + raise UploadRejected("Uploaded file is missing on disk.") + size = path.stat().st_size + if size > MAX_UPLOAD_BYTES: + raise UploadRejected( + f"Upload is {size / 1024 / 1024:.1f} MB; the Space limit is " + f"{MAX_UPLOAD_BYTES / 1024 / 1024:.0f} MB. Set ZSGDP_MAX_UPLOAD_BYTES to override." + ) + try: + profile = profile_document(path) + except Exception: # pragma: no cover - profiler is robust; this is belt-and-braces. + return + if profile.page_count > MAX_PAGE_COUNT: + raise UploadRejected( + f"Document has {profile.page_count} pages; the Space limit is " + f"{MAX_PAGE_COUNT}. Set ZSGDP_MAX_PAGE_COUNT to override." + ) + + +# Top-level artifact files surfaced as individual downloads. Nested +# directories like assets/ stay bundled in the zip only — they can be +# large for multi-page PDFs and would clutter the per-artifact list. +_INDIVIDUAL_ARTIFACT_NAMES = ( + "parsed_document.json", + "document.md", + "elements.jsonl", + "tables.jsonl", + "figures.jsonl", + "chunks.jsonl", + "chunking_plan.json", + "parser_metrics.json", + "quality_report.json", + "routing_report.json", + "profile.json", + "gpu_runtime.json", + "gpu_tasks.jsonl", + "gpu_task_report.json", + "artifact_manifest.json", + "conflict_report.json", +) + + +def _collect_artifact_files(output_dir: Path) -> list[str]: + """Return absolute paths for the top-level artifacts the Space surfaces. + + Order matches _INDIVIDUAL_ARTIFACT_NAMES so the UI listing is stable. + Missing files are silently skipped (different parse runs emit different + subsets — e.g. conflict_report.json only when multiple parsers ran). + """ + + paths: list[str] = [] + for name in _INDIVIDUAL_ARTIFACT_NAMES: + candidate = output_dir / name + if candidate.exists(): + paths.append(str(candidate)) + return paths + + +def _empty_outputs(reason: str, source: Path | None, *, rejected: bool, runtime: dict) -> tuple: + """Return-shape used for every error path. Centralised so the tuple width + can't drift between the success path and the four error paths.""" + + summary: dict[str, Any] = {"error": reason} + if source is not None: + summary["source"] = str(source) + if rejected: + summary["rejected"] = True + return ("", summary, {}, {}, {}, runtime, [], {}, {}, None, []) + + +def parse_uploaded_document(file_obj: Any, pipeline_mode: str): + if file_obj is None: + return _empty_outputs("Upload a document first.", None, rejected=False, runtime={}) + + source = Path(file_obj.name) + work_dir = Path(tempfile.mkdtemp(prefix="zeroshotgpu_")) + output_dir = work_dir / "parsed" + config_path = _config_path_for_mode(pipeline_mode) + + try: + _validate_upload(source) + except UploadRejected as exc: + _logger.warning( + "space_upload_rejected", + extra={"source_path": str(source), "reason": str(exc)}, + ) + runtime = runtime_status_for_mode(pipeline_mode) + return _empty_outputs(str(exc), source, rejected=True, runtime=runtime) + + try: + parsed = parse_document(source, output_dir, config_path=config_path) + except Exception as exc: # pragma: no cover - surfaced in the Space UI. + runtime = runtime_status_for_mode(pipeline_mode) + return _empty_outputs(str(exc), source, rejected=False, runtime=runtime) + + artifact_validation = validate_artifact_manifest(output_dir) + archive_path = shutil.make_archive(str(output_dir), "zip", output_dir) + individual_files = _collect_artifact_files(output_dir) + runtime = parsed.provenance.get("gpu_runtime", {}) + summary = { + "doc_id": parsed.doc_id, + "file_type": parsed.file_type, + "elements": len(parsed.elements), + "tables": len(parsed.tables), + "figures": len(parsed.figures), + "chunks": len(parsed.chunks), + "quality_score": parsed.quality_report.score, + "blocking": parsed.quality_report.has_blocking_failures, + "deployment": parsed.provenance.get("config_deployment", {}), + "runtime_device": runtime.get("device"), + "running_on_huggingface_space": runtime.get("running_on_huggingface_space"), + "artifact_manifest_valid": artifact_validation.get("valid"), + "artifact_count": artifact_validation.get("artifact_count"), + "artifact_checked_count": artifact_validation.get("checked_count"), + "individual_artifact_count": len(individual_files), + } + return ( + parsed.to_markdown(), + summary, + parsed.quality_report.to_dict(), + parsed.provenance.get("parser_metrics", {}), + parsed.provenance.get("chunking", {}), + runtime, + parsed.provenance.get("gpu_tasks", []), + parsed.provenance.get("gpu_task_report", {}), + artifact_validation, + archive_path, + individual_files, + ) + + +def _config_path_for_mode(pipeline_mode: str) -> Path | None: + env_config = os.environ.get("ZSGDP_CONFIG_PATH") + if env_config: + return Path(env_config) + if pipeline_mode == "Docling + PyMuPDF" and DOCLING_CONFIG.exists(): + return DOCLING_CONFIG + return None + + +def runtime_status_for_mode(pipeline_mode: str) -> dict: + return collect_gpu_runtime_status(load_config(_config_path_for_mode(pipeline_mode))).to_dict() + + +with gr.Blocks(title="zeroshotGPU") as demo: + gr.Markdown("# zeroshotGPU") + with gr.Row(): + upload = gr.File(label="Document", file_types=[".pdf", ".md", ".txt", ".html"]) + with gr.Column(): + pipeline = gr.Dropdown( + choices=["Docling + PyMuPDF", "Default lightweight"], + value="Docling + PyMuPDF", + label="Pipeline", + ) + parse_button = gr.Button("Parse", variant="primary") + archive = gr.File(label="Artifacts (zip)") + with gr.Tabs(): + with gr.Tab("Markdown"): + markdown = gr.Markdown(label="Canonical Markdown") + with gr.Tab("Run"): + summary = gr.JSON(label="Summary") + quality = gr.JSON(label="Quality Report") + parser_metrics = gr.JSON(label="Parser Metrics") + chunking = gr.JSON(label="Chunking Plan") + artifact_validation = gr.JSON(label="Artifact Manifest Validation") + with gr.Tab("Artifacts"): + gr.Markdown( + "Each top-level artifact is downloadable individually. " + "Nested assets (page renders, table/figure crops) stay bundled " + "in the zip above." + ) + individual_artifacts = gr.Files(label="Individual artifacts") + with gr.Tab("Runtime"): + runtime = gr.JSON(label="GPU Runtime", value=runtime_status_for_mode("Docling + PyMuPDF")) + gpu_tasks = gr.JSON(label="Planned GPU Tasks") + gpu_task_report = gr.JSON(label="GPU Task Preflight") + parse_button.click( + parse_uploaded_document, + inputs=[upload, pipeline], + outputs=[ + markdown, + summary, + quality, + parser_metrics, + chunking, + runtime, + gpu_tasks, + gpu_task_report, + artifact_validation, + archive, + individual_artifacts, + ], + ) + + +if __name__ == "__main__": + demo.launch() diff --git a/configs/default.yaml b/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33c7743d92622b0ed0c09943922eddb65a1777b9 --- /dev/null +++ b/configs/default.yaml @@ -0,0 +1,159 @@ +parsers: + text: + enabled: true + pymupdf: + enabled: true + docling: + enabled: false + do_ocr: false + do_table_structure: false + force_backend_text: true + marker: + enabled: false + command: null + timeout_seconds: 300 + output_args: "--output_dir {output_dir} --output_format markdown" + extra_args: "" + mineru: + enabled: false + command: null + timeout_seconds: 600 + output_args: "--output_dir {output_dir}" + extra_args: "" + olmocr: + enabled: false + command: null + timeout_seconds: 600 + output_args: "--output_dir {output_dir}" + extra_args: "" + paddleocr: + enabled: false + command: null + timeout_seconds: 600 + output_args: "--output_dir {output_dir}" + extra_args: "" + unstructured: + enabled: false + +routing: + run_multiple_on_hard_pages: true + max_primary_parsers_per_page: 2 + hard_page_threshold: 0.65 + scanned_text_threshold: 0.40 + table_density_threshold: 0.25 + formula_density_threshold: 0.15 + figure_density_threshold: 0.20 + +repair: + enabled: true + max_iterations: 3 + # Plan and dry-run GPU escalations for verification failures. + gpu_escalation: true + # Actually invoke the configured GPU/VLM backend on flagged regions. + # Defaults to false to avoid surprise model downloads on local runs; + # set true on the Space once GPU models are warm. + execute_gpu_escalations: false + table_repair: true + reading_order_repair: true + figure_repair: true + ocr_repair: true + +gpu: + backend: transformers + provider: huggingface_spaces + space_name: zeroshotGPU + batch_pages: true + validate_tasks: true + max_batch_size: 4 + max_gpu_seconds_per_doc: 120 + max_vlm_calls_per_doc: 30 + models: + vlm: + model_id: Qwen/Qwen2.5-VL-3B-Instruct + task: image-text-to-text + device: auto + dtype: bfloat16 + max_batch_size: 1 + ocr: + model_id: Qwen/Qwen2.5-VL-3B-Instruct + task: document-ocr + device: auto + dtype: bfloat16 + max_batch_size: 1 + table: + model_id: Qwen/Qwen2.5-VL-3B-Instruct + task: table-repair + device: auto + dtype: bfloat16 + max_batch_size: 1 + embedding: + model_id: jinaai/jina-embeddings-v3 + task: retrieval.passage + device: auto + dtype: bfloat16 + max_batch_size: 16 + task_model_roles: + vlm_route_repair: vlm + ocr_page: ocr + table_vlm_repair: table + figure_description: vlm + +pdf: + render_pages: true + render_dpi: 150 + crop_tables: true + crop_figures: true + asset_dir: assets + +quality: + accept_threshold: 0.88 + blocking_failures: + - empty_page + - invalid_table + - missing_text_coverage + - reading_order_failure + +chunking: + enabled: true + planner: agentic + baseline_strategy: recursive_structure + target_tokens: 512 + min_tokens: 120 + overlap_ratio: 0.15 + parent_child: true + parent_target_tokens: 1600 + page_level_for_paginated_docs: true + table_chunks: true + figure_chunks: true + contextual_prefix: false + contextual_retrieval: false + semantic_similarity_threshold: 0.18 + max_propositions_per_source: 8 + max_proposition_chunks: 64 + semantic_chunking: false + late_chunking: false + vision_guided: false + agentic_proposition_chunking: false + strategy_ladder: + - fixed_token_baseline + - recursive_structure + - metadata_enriched + - parent_child + - contextual_retrieval + - late_chunking + - semantic_chunking + - vision_guided + - agentic_proposition + +benchmarks: + retriever: + # `lexical` (default, model-free TF-IDF) or `embedding` (sentence-transformers). + # The `embedding` backend pulls model_id and task from gpu.models.embedding + # unless overridden here. Requires `pip install sentence-transformers`. + backend: lexical + model_id: null + task: null + +deployment: + target: huggingface_spaces + gpu_models_target: zeroshotGPU diff --git a/configs/docling.yaml b/configs/docling.yaml new file mode 100644 index 0000000000000000000000000000000000000000..146d63161c1b39345199f8081988c2bbca0612fc --- /dev/null +++ b/configs/docling.yaml @@ -0,0 +1,29 @@ +parsers: + # Both docling and pymupdf are enabled deliberately so the parser + # disagreement-rate metric has a comparison surface on PDF inputs. + # Disable one if you only need a single-parser baseline. + docling: + enabled: true + do_ocr: false + do_table_structure: false + force_backend_text: true + generate_page_images: false + generate_picture_images: false + generate_table_images: false + do_picture_description: false + do_picture_classification: false + do_formula_enrichment: false + do_code_enrichment: false + marker: + enabled: false + pymupdf: + enabled: true + +routing: + run_multiple_on_hard_pages: true + max_primary_parsers_per_page: 2 + +pdf: + render_pages: true + crop_tables: true + crop_figures: true diff --git a/configs/gpu.yaml b/configs/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..70aa60ed1e40f088eb2e20a693450df1e2901f23 --- /dev/null +++ b/configs/gpu.yaml @@ -0,0 +1,43 @@ +gpu: + backend: transformers + provider: huggingface_spaces + space_name: zeroshotGPU + batch_pages: true + validate_tasks: true + max_batch_size: 4 + max_gpu_seconds_per_doc: 120 + max_vlm_calls_per_doc: 30 + models: + vlm: + model_id: Qwen/Qwen2.5-VL-3B-Instruct + task: image-text-to-text + device: auto + dtype: bfloat16 + max_batch_size: 1 + ocr: + model_id: Qwen/Qwen2.5-VL-3B-Instruct + task: document-ocr + device: auto + dtype: bfloat16 + max_batch_size: 1 + table: + model_id: Qwen/Qwen2.5-VL-3B-Instruct + task: table-repair + device: auto + dtype: bfloat16 + max_batch_size: 1 + embedding: + model_id: jinaai/jina-embeddings-v3 + task: retrieval.passage + device: auto + dtype: bfloat16 + max_batch_size: 16 + task_model_roles: + vlm_route_repair: vlm + ocr_page: ocr + table_vlm_repair: table + figure_description: vlm + +deployment: + target: huggingface_spaces + gpu_models_target: zeroshotGPU diff --git a/configs/parsers.yaml b/configs/parsers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31c3029f8aec0963c617d90de091beb29a3b5d35 --- /dev/null +++ b/configs/parsers.yaml @@ -0,0 +1,33 @@ +parsers: + text: + enabled: true + pymupdf: + enabled: true + docling: + enabled: false + marker: + enabled: false + command: null + timeout_seconds: 300 + output_args: "--output_dir {output_dir} --output_format markdown" + extra_args: "" + mineru: + enabled: false + command: null + timeout_seconds: 600 + output_args: "--output_dir {output_dir}" + extra_args: "" + olmocr: + enabled: false + command: null + timeout_seconds: 600 + output_args: "--output_dir {output_dir}" + extra_args: "" + paddleocr: + enabled: false + command: null + timeout_seconds: 600 + output_args: "--output_dir {output_dir}" + extra_args: "" + unstructured: + enabled: false diff --git a/configs/routing.yaml b/configs/routing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd7c0f6ec548b37e113db13b604a2b0a0fe471a3 --- /dev/null +++ b/configs/routing.yaml @@ -0,0 +1,8 @@ +routing: + run_multiple_on_hard_pages: true + max_primary_parsers_per_page: 2 + hard_page_threshold: 0.65 + scanned_text_threshold: 0.40 + table_density_threshold: 0.25 + formula_density_threshold: 0.15 + figure_density_threshold: 0.20 diff --git a/docs/space_smoke.md b/docs/space_smoke.md new file mode 100644 index 0000000000000000000000000000000000000000..091a6d7ce9035054aa7cd9a2f713ecdb6f54aa00 --- /dev/null +++ b/docs/space_smoke.md @@ -0,0 +1,269 @@ +# Hugging Face Space smoke-test checklist + +This is the deferred deployment-readiness work that can only be exercised on +real GPU hardware against real models / external CLIs. Run each smoke once +against a duplicated `zeroshotGPU` Space (or your own dev Space). Each entry +gives the exact env vars / config flips, the command to trigger, and the +structured log lines you should expect. + +All log lines below assume the Space is run with `ZSGDP_LOG_LEVEL=INFO` and +`ZSGDP_LOG_JSON=1`. `app.py` sets these automatically when `SPACE_ID` is in +the environment, so on a normal Space you do not need to set them yourself. +The HF Spaces logs page will surface the JSON records on stderr. + +--- + +## Pre-flight + +1. Duplicate the Space, give it `l4x1` hardware. +2. Make sure these are set in **Space settings → Variables and secrets**: + - `ZSGDP_LOG_LEVEL=INFO` + - `ZSGDP_LOG_JSON=1` + - (Optional, only for parser smokes that hit a private repo) `HF_TOKEN`. +3. In the Space's `requirements.txt`, uncomment the dependency block matching + the smoke you are running. Do **one smoke per Space deploy** — combining + them risks an OOM or slow cold-start on the L4. +4. Push and wait for the Space to build. First-build cold-start with a model + download is ~5-10 minutes; subsequent restarts are seconds. + +After deploy, watch the **Logs** tab for the `parse_start` event. If you do +not see structured JSON lines there, the logging config is not active — +double-check `ZSGDP_LOG_JSON=1` in the Space variables. + +## Automated runner + +Each smoke below has an automated counterpart in +`scripts/run_space_smoke.py`. From a Space JupyterLab terminal (or any +shell with the project installed): + +```bash +# Run all smokes whose deps are installed; skip the rest with hints: +python -m scripts.run_space_smoke --output ./space_smoke_report.json + +# Run only specific smokes: +python -m scripts.run_space_smoke --smoke lexical --smoke ablation + +# CI-strict mode: treat skipped smokes as failures (use after you've +# uncommented the deps for the smoke you intend to run): +python -m scripts.run_space_smoke --smoke embedding --strict +``` + +The runner reports `pass` / `fail` / `skip` / `error` per smoke, plus +elapsed seconds and a `detail` block with the metrics it gathered. The +manual procedure below is the fallback when you want to inspect the UI +directly or test something the runner doesn't cover (e.g. uploading a +specific real PDF rather than a synthetic fixture). + +--- + +## Smoke 1 — Lexical retriever benchmark (model-free) + +Confirms the Space's parsing + benchmark plumbing works end-to-end before +adding any model dependency. + +**Setup:** +- Default `requirements.txt` (no uncommenting needed). +- Default config (no flips). + +**Trigger:** upload a small markdown file via the Gradio UI. + +**Expected log lines (in order):** +- `parse_start` with `doc_id`, `file_type`, `device` (likely `cuda`). +- One `parser_candidate` per parser that ran (typically `text`, possibly + `pymupdf` and `docling` if the file was a PDF). +- Possibly one or more `repair_iteration` records if quality < threshold. +- `parse_end` with `quality_score`, `repair_iterations`, `chunk_count`. + +**Pass criteria:** +- All log lines appear with `doc_id` populated. +- `parse_end.quality_score >= 0.85` for a clean markdown doc. +- No `parser_failed` or `gpu_task_blocked` records. + +--- + +## Smoke 2 — Embedding retriever (jina-embeddings-v3) + +Confirms `sentence-transformers` lazy-load path and that jina-v3 specifically +runs on the L4 with `trust_remote_code=True`. + +**Setup:** +- In `requirements.txt`, uncomment `transformers` and `sentence-transformers` + lines. +- Add `configs/space_embedding.yaml` to the repo with: + + ```yaml + benchmarks: + retriever: + backend: embedding + model_id: jinaai/jina-embeddings-v3 + task: retrieval.passage + ``` + +- In `app.py` set `os.environ["ZSGDP_CONFIG_PATH"] = "configs/space_embedding.yaml"`, + or pass via the env var configured in Space variables. + +**Trigger:** upload any markdown / PDF; the benchmark CLI is not reachable +from the Gradio UI today, so for the embedding-retriever smoke you'd need +to run `zsgdp benchmark --input ./fixtures --output ./out` from a Space +**JupyterLab** session against a small input dir. + +**Expected log lines:** +- First call: a 30–90s pause while jina-v3 weights download (no log lines + during this — torch logs go to its own logger). Then `parse_start`. +- After the first parse, subsequent calls are fast (model is in memory). + +**Pass criteria:** +- Benchmark completes without an exception. +- `summary["mean_retrieval_recall_at_5"] >= 0.7` on a small distinct-text + corpus. +- No `gpu_task_blocked` records (those are repair-related, not retrieval). +- The parse_end record's `device` field reads `cuda`. + +**Failure modes to watch:** +- `RuntimeError: EmbeddingRetriever requires sentence-transformers` → + package not in `requirements.txt`. +- CUDA OOM → switch to a smaller embedding model + (`sentence-transformers/all-MiniLM-L6-v2`) for the smoke and confirm the + wiring before retrying jina-v3. + +--- + +## Smoke 3 — Live GPU repair on a malformed table + +Confirms the repair loop's GPU escalation path actually invokes the +configured VLM and that the result is applied to the merged document. + +**Setup:** +- In `requirements.txt`, uncomment `transformers` (sentence-transformers + not needed for this smoke). +- Add `configs/space_gpu_repair.yaml`: + + ```yaml + parsers: + docling: + enabled: true + pymupdf: + enabled: true + repair: + enabled: true + gpu_escalation: true + execute_gpu_escalations: true # the bit that flips the live path on + gpu: + backend: transformers + models: + table: + model_id: Qwen/Qwen2.5-VL-3B-Instruct + task: table-repair + device: auto + dtype: bfloat16 + ``` + +- Set `ZSGDP_CONFIG_PATH=configs/space_gpu_repair.yaml` on the Space. + +**Trigger:** upload a PDF that contains a table the parsers will likely +mangle. A two-column financial statement page works well; if you don't +have one handy, take a Wikipedia article PDF that has a comparison table. + +**Expected log lines (in order):** +- `parse_start`. +- `parser_candidate` for docling and pymupdf (both should fire on a PDF). +- `repair_iteration` with `iteration=1`, `gpu_task_count >= 1`, + `gpu_dry_run=false`. +- One `gpu_task_executed` record per GPU task. `status` should be + `executed` and `elapsed_seconds` 1-10s for a 3B-param VLM on L4. +- A second `repair_iteration` with `iteration=2` only if iteration 1 + changed something and quality is still below threshold; otherwise the + loop terminates. +- `parse_end` with `repair_iterations >= 1`. + +**Pass criteria:** +- At least one `gpu_task_executed` with `status=executed`. +- The output `parsed_document.json` shows `parsed.tables[i].provenance.gpu_repair_task_id` set. +- No `gpu_task_blocked` records (would mean missing image_path or doc_id). + +**Failure modes to watch:** +- All `gpu_task_executed` records show `status=execution_failed` → + inspect `output.error` field; common causes are missing image_path + (the PDF doesn't render page crops because `pdf.crop_tables=true` isn't + set) or a CUDA OOM. +- No `repair_iteration` records → the verifier didn't flag any + blocking issues; pick a different input PDF. + +--- + +## Smoke 4 — Per-parser ablation across docling + pymupdf + +Confirms the ablation runner produces a comparison CSV and that each arm's +artifacts are isolated. No GPU dependency, runs on default Space hardware. + +**Setup:** default config, no requirements.txt changes. + +**Trigger:** Space JupyterLab terminal: + +```bash +zsgdp benchmark-ablate \ + --input ./fixtures/pdfs \ + --output ./out/ablation \ + --parser docling --parser pymupdf +``` + +**Expected log lines:** one parse cycle per arm (parse_start through +parse_end), three arms total (docling-only, pymupdf-only, merged). + +**Pass criteria:** +- `out/ablation/ablation_comparison.csv` has 3 rows. +- Each arm's `mean_quality_score` is non-zero. +- The merged arm's `mean_quality_score` is `>= max(per-parser arms)`. + +--- + +## Smoke 5 — External parser CLI (Marker) + +The riskiest of the four external adapters because Marker's argv schema +has changed several times. Per-Space, do not bundle with other smokes. + +**Setup:** +- Uncomment `marker-pdf` in `requirements.txt`. +- Add `configs/space_marker.yaml`: + + ```yaml + parsers: + text: + enabled: false + pymupdf: + enabled: false + marker: + enabled: true + timeout_seconds: 300 + output_args: ["--output_dir", "{output_dir}", "--output_format", "markdown"] + extra_args: [] + ``` + +- Set `ZSGDP_CONFIG_PATH=configs/space_marker.yaml`. + +**Trigger:** upload a small PDF (1–3 pages) via the Gradio UI. + +**Expected log lines:** +- `parse_start`. +- `parser_candidate` for `marker` with non-zero `element_count`. +- `parse_end` with `candidate_parsers=["marker"]`. + +**Pass criteria:** +- No `parser_failed` record for marker. +- Output Markdown has reasonable content (open the artifact zip and check). +- If `parser_failed` fires, look at `extra.error` — most common cause is + argv schema drift; tweak `output_args` in the config and retry. + +--- + +## What "deployment ready" means after this checklist + +If smokes 1–3 pass on a fresh duplicated Space, the project is genuinely +deployable for the Docling + PyMuPDF + Qwen2.5-VL-3B repair stack. Smokes 4 +and 5 are nice-to-have — the per-parser ablation works locally too, and +external parsers stay flagged "experimental" until you actively need them. + +Open the `parsed_document.json` from each smoke, copy the `quality_score`, +`mean_layout_f1` (where applicable), and any §29-relevant metric into +`README.md` under a new "Production benchmark numbers" section. That +publishes evidence that the success criteria are met against real data. diff --git a/examples/parse_folder.py b/examples/parse_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9bbfdae3046a997d11e79dd940f04c204f00f4 --- /dev/null +++ b/examples/parse_folder.py @@ -0,0 +1,27 @@ +"""Parse a folder sequentially.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +from zsgdp import parse_document + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("output") + args = parser.parse_args() + + input_dir = Path(args.input) + output_dir = Path(args.output) + output_dir.mkdir(parents=True, exist_ok=True) + for path in sorted(item for item in input_dir.iterdir() if item.is_file()): + parsed = parse_document(path, output_dir / path.stem) + print(f"{path.name}: score={parsed.quality_report.score:.2f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/parse_pdf.py b/examples/parse_pdf.py new file mode 100644 index 0000000000000000000000000000000000000000..020f4d4cbc73b84a6c25d6768029d278b2dbb49d --- /dev/null +++ b/examples/parse_pdf.py @@ -0,0 +1,25 @@ +"""Parse one PDF with the MVP pipeline.""" + +from __future__ import annotations + +import argparse + +from zsgdp import parse_document + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("output") + args = parser.parse_args() + parsed = parse_document(args.input, args.output) + print( + f"score={parsed.quality_report.score:.2f} " + f"elements={len(parsed.elements)} tables={len(parsed.tables)} " + f"figures={len(parsed.figures)} chunks={len(parsed.chunks)}" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/run_benchmark.py b/examples/run_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..34a2931a1b0677c80f4ff429932c72e2ebe48a43 --- /dev/null +++ b/examples/run_benchmark.py @@ -0,0 +1,33 @@ +"""Minimal benchmark runner placeholder.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from time import perf_counter + +from zsgdp import parse_document +from zsgdp.benchmarks.throughput import pages_per_second + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("output") + args = parser.parse_args() + + input_dir = Path(args.input) + output_dir = Path(args.output) + output_dir.mkdir(parents=True, exist_ok=True) + total_pages = 0 + started = perf_counter() + for path in sorted(item for item in input_dir.iterdir() if item.is_file()): + parsed = parse_document(path, output_dir / path.stem) + total_pages += len(parsed.pages) + elapsed = perf_counter() - started + print(f"pages={total_pages} seconds={elapsed:.2f} pages_per_second={pages_per_second(total_pages, elapsed):.2f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..ab7d68ea1c244ff098d12318f3e4ea8d2cdeaba0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "zero-shot-gpu-doc-parser" +version = "0.1.0" +description = "Zero-shot GPU document parsing and agentic chunking control plane." +readme = "README.md" +requires-python = ">=3.11" +license = { text = "MIT" } +authors = [{ name = "Zero-Shot GPU Document Parser Contributors" }] +dependencies = [] + +[project.optional-dependencies] +pdf = ["pymupdf>=1.24.0,<1.28.0"] +yaml = ["pyyaml>=6.0.1,<7.0.0"] +docling = ["docling>=2.0.0,<3.0.0"] +# `spaces` mirrors requirements.txt at the root, which is what HF Spaces +# installs verbatim. Keep these two in sync; torch is intentionally absent +# because the l4x1 Space image preinstalls a CUDA-matched build. +spaces = [ + "gradio>=4.44.0,<7.0.0", + "pymupdf>=1.24.0,<1.28.0", + "pyyaml>=6.0.1,<7.0.0", + "docling>=2.0.0,<3.0.0", +] +embedding = ["sentence-transformers>=3.0.0,<4.0.0", "transformers>=4.45.0,<6.0.0"] +gpu_repair = ["transformers>=4.45.0,<6.0.0"] +dev = ["pytest>=8.0.0"] + +[project.scripts] +zsgdp = "zsgdp.cli:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["zsgdp*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["."] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..79940b71bc0006d446a4e53758b247288e5c074a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +# Hugging Face Spaces dependencies for zeroshotGPU. +# +# Versions are pinned to tested upper bounds within each major. Bump these +# when you have run `python -m unittest discover` and the benchmark suite +# successfully against a new release. +# +# Torch is intentionally NOT pinned here. The l4x1 Space image preinstalls a +# CUDA-matched torch build; pinning torch in this file overrides it and risks +# a runtime/driver mismatch. If you're running locally without the Space +# preinstall, install torch separately via the recommended channel for your +# platform (e.g. `pip install torch --index-url https://download.pytorch.org/whl/cu121`). + +gradio>=4.44.0,<7.0.0 +pymupdf>=1.24.0,<1.28.0 +pyyaml>=6.0.1,<7.0.0 +docling>=2.0.0,<3.0.0 + +# Optional GPU/embedding stack. Uncomment to enable the embedding retriever +# (benchmarks.retriever.backend=embedding) and live GPU repair escalations +# (repair.execute_gpu_escalations=true). Both are off by default. +# +# transformers>=4.45.0,<6.0.0 +# sentence-transformers>=3.0.0,<4.0.0 + +# Optional external parser CLIs. Each adds a non-trivial install footprint; +# enable only the ones the Space hardware can support. Adapter shells out to +# the CLI binary (see zsgdp/parsers/external.py); these adapters have not +# been smoke-tested against a live install — verify the argv schema before +# enabling in production. +# +# marker-pdf>=1.0.0 +# mineru +# unstructured>=0.15.0 diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/run_space_smoke.py b/scripts/run_space_smoke.py new file mode 100644 index 0000000000000000000000000000000000000000..dbbcb59d866b8dfd7f1af2d0206ae7f3c8e099de --- /dev/null +++ b/scripts/run_space_smoke.py @@ -0,0 +1,455 @@ +"""Space-side smoke validation runner. + +Automates the smokes documented in docs/space_smoke.md so a Space operator +can run one command and get a JSON report of which smokes passed, which +were skipped (missing deps), and which failed (with diagnostic context). + +Usage: + + # Run all smokes that have their deps installed: + python -m scripts.run_space_smoke --output ./space_smoke_report.json + + # Run only a subset: + python -m scripts.run_space_smoke --smoke lexical --smoke ablation + + # Force-fail on skipped smokes (CI-style strict mode): + python -m scripts.run_space_smoke --strict + +The runner does NOT install missing dependencies — that's deliberately the +operator's job (each smoke's deps add Space build time and download cost). +A skipped smoke prints the exact `pip install` line you'd need. + +Smokes mirror docs/space_smoke.md: + + lexical - model-free benchmark on a synthetic markdown corpus + ablation - per-parser ablation runner (text vs pymupdf) + embedding - sentence-transformers / jina-embeddings-v3 retrieval + gpu_repair - live Qwen2.5-VL invocation against a malformed table + marker - shell out to marker_single on a small PDF (if installed) +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import shutil +import subprocess +import sys +import tempfile +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +@dataclass(slots=True) +class SmokeResult: + name: str + status: str # "pass" | "fail" | "skip" | "error" + elapsed_seconds: float = 0.0 + detail: dict[str, Any] = field(default_factory=dict) + skip_reason: str = "" + install_hint: str = "" + + +@dataclass(slots=True) +class SmokeReport: + smokes: list[SmokeResult] = field(default_factory=list) + + @property + def passed(self) -> bool: + return all(item.status in {"pass", "skip"} for item in self.smokes) + + def to_dict(self) -> dict[str, Any]: + return { + "smokes": [ + { + "name": item.name, + "status": item.status, + "elapsed_seconds": round(item.elapsed_seconds, 3), + "detail": item.detail, + "skip_reason": item.skip_reason, + "install_hint": item.install_hint, + } + for item in self.smokes + ], + "summary": { + "total": len(self.smokes), + "passed": sum(1 for item in self.smokes if item.status == "pass"), + "failed": sum(1 for item in self.smokes if item.status == "fail"), + "errored": sum(1 for item in self.smokes if item.status == "error"), + "skipped": sum(1 for item in self.smokes if item.status == "skip"), + }, + } + + +# --- Individual smokes ------------------------------------------------------- + + +def _make_distinctive_corpus(root: Path) -> Path: + """Build a small corpus with three sentences distinct enough that the + synthetic-QA generator picks one query per chunk.""" + + src = root / "in" + src.mkdir() + (src / "doc.md").write_text( + "# Sample Doc\n\n" + "Apples grow on trees in the orchard during autumn harvest season.\n\n" + "Submarines navigate beneath the ocean using sonar pulses across waters.\n\n" + "Mountains rise above the clouds in the distant horizon line.\n", + encoding="utf-8", + ) + return src + + +def smoke_lexical() -> SmokeResult: + started = time.perf_counter() + from zsgdp.benchmarks.parser_quality import run_parser_benchmark + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + src = _make_distinctive_corpus(tmp_path) + out = tmp_path / "out" + try: + summary = run_parser_benchmark(src, out, dataset_name="custom_folder") + except Exception as exc: + return SmokeResult( + name="lexical", + status="error", + elapsed_seconds=time.perf_counter() - started, + detail={"exception": str(exc)}, + ) + + quality = float(summary.get("mean_quality_score", 0.0)) + recall = float(summary.get("mean_retrieval_recall_at_1", 0.0)) + passed = quality >= 0.85 and recall >= 0.7 + return SmokeResult( + name="lexical", + status="pass" if passed else "fail", + elapsed_seconds=time.perf_counter() - started, + detail={ + "mean_quality_score": quality, + "mean_retrieval_recall_at_1": recall, + "documents_evaluated": summary.get("document_count"), + }, + ) + + +def smoke_ablation() -> SmokeResult: + started = time.perf_counter() + from zsgdp.benchmarks.ablation_runner import run_parser_ablations + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + src = _make_distinctive_corpus(tmp_path) + out = tmp_path / "out" + try: + comparison = run_parser_ablations( + src, + out, + parsers=["text", "pymupdf"], + dataset_name="custom_folder", + ) + except Exception as exc: + return SmokeResult( + name="ablation", + status="error", + elapsed_seconds=time.perf_counter() - started, + detail={"exception": str(exc)}, + ) + + comparison_csv_exists = (out / "ablation_comparison.csv").exists() + + arms = [row["arm"] for row in comparison["rows"]] + expected_arms = {"text", "pymupdf", "merged"} + passed = comparison["arm_count"] == 3 and set(arms) == expected_arms and comparison_csv_exists + return SmokeResult( + name="ablation", + status="pass" if passed else "fail", + elapsed_seconds=time.perf_counter() - started, + detail={ + "arm_count": comparison["arm_count"], + "arms": arms, + "comparison_csv_emitted": comparison_csv_exists, + }, + ) + + +def smoke_embedding() -> SmokeResult: + started = time.perf_counter() + if importlib.util.find_spec("sentence_transformers") is None: + return SmokeResult( + name="embedding", + status="skip", + elapsed_seconds=time.perf_counter() - started, + skip_reason="sentence-transformers not installed", + install_hint="python -m pip install 'zero-shot-gpu-doc-parser[embedding]'", + ) + + from zsgdp.benchmarks.embedding_retriever import EmbeddingRetriever + from zsgdp.benchmarks.parser_quality import run_parser_benchmark + + # Try to load the configured embedding model. If the load fails (no HF + # token, download error, OOM at import time), we report it as a skip + # with the exception text so the operator sees what to fix without the + # whole smoke run blowing up. + try: + retriever = EmbeddingRetriever() + retriever._ensure_embedder() # type: ignore[attr-defined] # private but intentional + except Exception as exc: + return SmokeResult( + name="embedding", + status="skip", + elapsed_seconds=time.perf_counter() - started, + skip_reason=f"embedding model failed to load: {exc}", + install_hint="Set HF_TOKEN if the model is gated, or downsize via " + "benchmarks.retriever.model_id (e.g. sentence-transformers/all-MiniLM-L6-v2).", + ) + + config_overrides = {"benchmarks": {"retriever": {"backend": "embedding"}}} + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + src = _make_distinctive_corpus(tmp_path) + out = tmp_path / "out" + config_path = tmp_path / "config.yaml" + # Inline config write — keeps the smoke self-contained. + config_path.write_text( + "benchmarks:\n retriever:\n backend: embedding\n", + encoding="utf-8", + ) + try: + summary = run_parser_benchmark(src, out, config_path=config_path, dataset_name="custom_folder") + except Exception as exc: + return SmokeResult( + name="embedding", + status="error", + elapsed_seconds=time.perf_counter() - started, + detail={"exception": str(exc)}, + ) + + recall_5 = float(summary.get("mean_retrieval_recall_at_5", 0.0)) + passed = recall_5 >= 0.7 + return SmokeResult( + name="embedding", + status="pass" if passed else "fail", + elapsed_seconds=time.perf_counter() - started, + detail={ + "mean_retrieval_recall_at_5": recall_5, + "mean_retrieval_recall_at_1": float(summary.get("mean_retrieval_recall_at_1", 0.0)), + "documents_evaluated": summary.get("document_count"), + }, + ) + + +def smoke_gpu_repair() -> SmokeResult: + started = time.perf_counter() + if importlib.util.find_spec("transformers") is None: + return SmokeResult( + name="gpu_repair", + status="skip", + elapsed_seconds=time.perf_counter() - started, + skip_reason="transformers not installed", + install_hint="python -m pip install 'zero-shot-gpu-doc-parser[gpu_repair]'", + ) + + # Don't actually instantiate the transformers pipeline here — it would + # download multi-GB Qwen2.5-VL weights even on a dry probe. Instead, we + # smoke-test the wiring: a dry-run task plan, and report whether the + # underlying client class can be imported. Operators who want a real + # model invocation should use `run-gpu-tasks --execute` against a parsed + # output directory; the result lands in repair.gpu_escalation.results. + from zsgdp.gpu.transformers_client import TransformersClient + from zsgdp.pipeline import parse_document + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + src = tmp_path / "report.md" + # Malformed table (header has 2 columns; data row has 3) forces the + # repair loop to plan a table_vlm_repair task. + src.write_text( + "# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 | 3 |\n", + encoding="utf-8", + ) + out = tmp_path / "out" + try: + parsed = parse_document(src, out) + except Exception as exc: + return SmokeResult( + name="gpu_repair", + status="error", + elapsed_seconds=time.perf_counter() - started, + detail={"exception": str(exc)}, + ) + + repair = parsed.provenance.get("repair", {}) + gpu_escalation = repair.get("gpu_escalation") or {} + task_count = int(gpu_escalation.get("task_count") or 0) + iterations = parsed.provenance.get("repair_iterations") or [] + # We can confirm: + # * Dry-run plan ran (task_count >= 1 for the malformed table) + # * The repair loop iterated at least once + # * The TransformersClient class is importable for live execution + can_execute = TransformersClient is not None + passed = task_count >= 1 and len(iterations) >= 1 and can_execute + return SmokeResult( + name="gpu_repair", + status="pass" if passed else "fail", + elapsed_seconds=time.perf_counter() - started, + detail={ + "dry_run_task_count": task_count, + "repair_iterations": len(iterations), + "transformers_client_importable": can_execute, + "note": "This smoke verifies wiring only. To verify model invocation " + "end-to-end, set repair.execute_gpu_escalations=true in config " + "and run zsgdp run-gpu-tasks --execute against a parsed dir.", + }, + ) + + +def smoke_marker() -> SmokeResult: + started = time.perf_counter() + if shutil.which("marker_single") is None and shutil.which("marker") is None: + return SmokeResult( + name="marker", + status="skip", + elapsed_seconds=time.perf_counter() - started, + skip_reason="neither `marker_single` nor `marker` found on PATH", + install_hint="python -m pip install marker-pdf", + ) + + # Marker is heavy enough that even a probe call can take 30+s on first + # invocation (model load). We confirm the registry adapter reports + # available, but don't run a full parse here — surface that as a manual + # follow-up via the smoke checklist. + from zsgdp.parsers.registry import get_parser + + try: + adapter = get_parser("marker") + except KeyError as exc: + return SmokeResult( + name="marker", + status="error", + elapsed_seconds=time.perf_counter() - started, + detail={"exception": str(exc)}, + ) + available = bool(adapter.available()) + return SmokeResult( + name="marker", + status="pass" if available else "fail", + elapsed_seconds=time.perf_counter() - started, + detail={ + "adapter_reports_available": available, + "note": "End-to-end Marker parse is intentionally not run here " + "(cold-load is heavy). See docs/space_smoke.md Smoke 5 " + "for the manual upload-and-parse procedure.", + }, + ) + + +SMOKE_REGISTRY: dict[str, Callable[[], SmokeResult]] = { + "lexical": smoke_lexical, + "ablation": smoke_ablation, + "embedding": smoke_embedding, + "gpu_repair": smoke_gpu_repair, + "marker": smoke_marker, +} + + +# --- Driver ------------------------------------------------------------------ + + +def run_smokes(names: list[str] | None = None) -> SmokeReport: + selected = names or list(SMOKE_REGISTRY) + report = SmokeReport() + for name in selected: + smoke = SMOKE_REGISTRY.get(name) + if smoke is None: + report.smokes.append( + SmokeResult( + name=name, + status="error", + detail={"exception": f"unknown smoke: {name}"}, + ) + ) + continue + try: + result = smoke() + except Exception as exc: + result = SmokeResult( + name=name, + status="error", + detail={"exception": f"{type(exc).__name__}: {exc}"}, + ) + report.smokes.append(result) + return report + + +def format_text_summary(report: SmokeReport, *, strict: bool = False) -> str: + lines: list[str] = [] + for item in report.smokes: + marker = { + "pass": "ok", + "fail": "FAIL", + "skip": "skip", + "error": "ERROR", + }.get(item.status, item.status.upper()) + line = f" [{marker}] {item.name} ({item.elapsed_seconds:.2f}s)" + if item.status == "skip": + line += f" reason={item.skip_reason}" + elif item.status == "fail": + line += f" detail={json.dumps(item.detail, default=str)}" + elif item.status == "error": + line += f" detail={json.dumps(item.detail, default=str)}" + lines.append(line) + + summary = report.to_dict()["summary"] + overall = "PASS" if (report.passed and (not strict or summary["skipped"] == 0)) else "FAIL" + lines.append( + f"smoke: {overall} passed={summary['passed']} failed={summary['failed']} " + f"errored={summary['errored']} skipped={summary['skipped']}" + ) + return "\n".join(lines) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + prog="run_space_smoke", + description="Run zsgdp Space-side smoke validations.", + ) + parser.add_argument( + "--smoke", + action="append", + dest="smokes", + choices=list(SMOKE_REGISTRY), + help="Smoke to run. Repeat to run multiple. Default: all registered smokes.", + ) + parser.add_argument("--output", help="Optional JSON report path.") + parser.add_argument( + "--strict", + action="store_true", + help="Treat skipped smokes as failures (useful in CI when all deps must be present).", + ) + args = parser.parse_args(argv) + + report = run_smokes(args.smokes) + print(format_text_summary(report, strict=args.strict)) + + if args.output: + Path(args.output).write_text( + json.dumps(report.to_dict(), indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + + summary = report.to_dict()["summary"] + if summary["failed"] or summary["errored"]: + return 1 + if args.strict and summary["skipped"]: + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38bb211b065081810c8c65af2d44ca10dc9fb059 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package.""" diff --git a/tests/regression/README.md b/tests/regression/README.md new file mode 100644 index 0000000000000000000000000000000000000000..44719370a8c179f0e0a54cef7ff4caa743e6c848 --- /dev/null +++ b/tests/regression/README.md @@ -0,0 +1,97 @@ +# Regression fixtures + +Each fixture is a `(.input., .expected.json)` pair under +`fixtures/`. The runner in `test_regression.py` parses every input through +`parse_document` and compares the resulting `ParsedDocument` against the +snapshot in `.expected.json` with explicit tolerances. + +## Fixture file shape + +`.expected.json` has these keys (all optional except `name`): + +```json +{ + "name": "human-readable identifier", + "config": "configs/docling.yaml", + "selected_parsers": ["text"], + "tolerances": { + "quality_score_min": 0.85, + "element_count_range": [3, 6], + "table_count": 1, + "figure_count": 0, + "chunk_count_min": 1, + "blocking_failures": false, + "must_contain_markdown": ["# Report", "Apples grow"], + "must_not_contain_markdown": ["TODO", "FIXME"] + } +} +``` + +Tolerance keys (all optional): + +- `quality_score_min` (float): assert `parsed.quality_report.score >= value`. +- `quality_score_max` (float): assert `parsed.quality_report.score <= value`. +- `element_count` (int) or `element_count_range` ([min, max]). +- `table_count` (int) or `table_count_range`. +- `figure_count` (int) or `figure_count_range`. +- `chunk_count_min` (int): assert at least N chunks. +- `chunk_count_max` (int): assert at most N chunks. +- `blocking_failures` (bool): assert `quality_report.has_blocking_failures` matches. +- `must_contain_markdown` (list[str]): each string must appear in + `parsed.to_markdown()`. +- `must_not_contain_markdown` (list[str]): each string must NOT appear. +- `must_contain_quality_metrics` (list[str]): each metric key must appear in + `quality_report.metrics`. +- `parser_disagreement_rate_max` (float): assert disagreement <= value. +- `repair_resolution_rate_min` (float): assert resolution >= value. + +Missing keys are not asserted (no false failures from over-specification). + +## Adding a fixture + +1. Drop the input document under `fixtures/`. PDFs, markdown, html, txt all + work via the standard pipeline. +2. Run a one-off `parse_document` against it locally and inspect the output. +3. Hand-write `.expected.json` with the constraints you want to lock + down. Prefer ranges over exact counts where reasonable variance exists. +4. Run `python3.11 -m unittest tests.test_regression`. It auto-discovers. + +## Performance baselines (opt-in) + +A fixture may include a `performance` block with throughput floors: + +```json +{ + "performance": { + "repeats": 2, + "max_elapsed_seconds": 2.0, + "min_pages_per_second": 0.5, + "always_enforce": false + } +} +``` + +Keys: + +- `repeats` (int, default 2): number of warm parses to time. The median + elapsed is compared against the floor so a single cold-import outlier + does not flag. +- `max_elapsed_seconds`: parse must finish under this in median. +- `min_pages_per_second`: median pages/sec must meet or beat this. +- `always_enforce` (bool, default false): when true, perf is always checked. + +Otherwise perf is gated on `ZSGDP_REGRESSION_PERF=1` so slow CI runners +don't get noisy. Floors should be **catastrophic-regression guards** — set +them ~50–100x slacker than your local median, not tight perf bars. The +point is to catch "parsing a tiny markdown doc now takes 30 seconds," +not to track 5 % perf shifts. + +To set a baseline for a new fixture: parse it 5 times locally, take the +median, multiply by ~10–80x for the `max_elapsed_seconds` floor. + +## When a regression fires + +The failure message points at the specific tolerance that broke. Don't blindly +loosen the tolerance — investigate whether the regression is real first +(parser-version bump, repair-loop drift, chunk planner change). If the new +behavior is intentional and better, regenerate the snapshot. diff --git a/tests/regression/__init__.py b/tests/regression/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/regression/fixtures/markdown_basic.expected.json b/tests/regression/fixtures/markdown_basic.expected.json new file mode 100644 index 0000000000000000000000000000000000000000..3c1b5710c9fec39359acf6bbb1ec360ccfa6be29 --- /dev/null +++ b/tests/regression/fixtures/markdown_basic.expected.json @@ -0,0 +1,31 @@ +{ + "name": "markdown_basic", + "tolerances": { + "quality_score_min": 0.9, + "blocking_failures": false, + "element_count_range": [4, 8], + "table_count": 1, + "figure_count": 0, + "chunk_count_min": 4, + "must_contain_markdown": [ + "# Quarterly Report", + "Apples grow on trees in the orchard", + "| Region | Q1 | Q2 |", + "Submarines navigate beneath the ocean" + ], + "must_not_contain_markdown": ["TODO", "FIXME"], + "must_contain_quality_metrics": [ + "document_text_coverage", + "parser_disagreement_rate", + "repair_resolution_rate" + ], + "parser_disagreement_rate_max": 0.5, + "repair_resolution_rate_min": 0.5 + }, + "performance": { + "_comment": "Floors are catastrophic-regression guards, not tight perf bars. Median of 2 warm runs (cold-import outlier dropped) was ~6ms locally; the floor is 80x that to absorb slow CI. Enable with ZSGDP_REGRESSION_PERF=1 or set always_enforce: true.", + "repeats": 2, + "max_elapsed_seconds": 2.0, + "min_pages_per_second": 0.5 + } +} diff --git a/tests/regression/fixtures/markdown_basic.input.md b/tests/regression/fixtures/markdown_basic.input.md new file mode 100644 index 0000000000000000000000000000000000000000..c12f03ddf8977f184f1448a6b471443896409eb2 --- /dev/null +++ b/tests/regression/fixtures/markdown_basic.input.md @@ -0,0 +1,14 @@ +# Quarterly Report + +Apples grow on trees in the orchard during the autumn harvest season. + +## Revenue + +| Region | Q1 | Q2 | +| --- | --- | --- | +| North America | 10 | 12 | +| Europe | 8 | 9 | + +## Outlook + +Submarines navigate beneath the ocean using sonar pulses across waters. diff --git a/tests/regression/test_regression.py b/tests/regression/test_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8291195e2be0a7a25aab5bf7b0082eb5f8ace0 --- /dev/null +++ b/tests/regression/test_regression.py @@ -0,0 +1,255 @@ +"""Snapshot regression tests against fixtures in this directory. + +Discovery: every .expected.json under fixtures/ pairs with a sibling +.input.. The runner parses the input, then asserts each tolerance +in the expected file. Tolerance keys are documented in fixtures/README.md. + +Performance baselines are opt-in per fixture via a `performance` block in +the expected file. They run only when ZSGDP_REGRESSION_PERF=1 (or when the +performance block has `always_enforce: true`) so a slow CI runner does not +fail on transient noise. When enabled, the parse is run twice and the +median elapsed time is compared against the floor. +""" + +from __future__ import annotations + +import json +import os +import statistics +import tempfile +import time +import unittest +import unittest.mock +from pathlib import Path +from typing import Any + +from zsgdp.pipeline import parse_document + +FIXTURE_DIR = Path(__file__).parent / "fixtures" + + +def _discover_fixtures() -> list[tuple[str, Path, Path]]: + pairs: list[tuple[str, Path, Path]] = [] + if not FIXTURE_DIR.exists(): + return pairs + for expected in sorted(FIXTURE_DIR.glob("*.expected.json")): + name = expected.name[: -len(".expected.json")] + candidates = sorted(FIXTURE_DIR.glob(f"{name}.input.*")) + if not candidates: + continue + pairs.append((name, candidates[0], expected)) + return pairs + + +def _check_int_or_range(actual: int, exact: Any, range_value: Any, label: str) -> str | None: + if exact is not None and int(exact) != actual: + return f"{label}: expected {exact}, got {actual}" + if isinstance(range_value, (list, tuple)) and len(range_value) == 2: + lo, hi = int(range_value[0]), int(range_value[1]) + if not (lo <= actual <= hi): + return f"{label}: expected in [{lo}, {hi}], got {actual}" + return None + + +def _evaluate(parsed, tolerances: dict[str, Any]) -> list[str]: + failures: list[str] = [] + score = float(parsed.quality_report.score) + if "quality_score_min" in tolerances and score < float(tolerances["quality_score_min"]): + failures.append(f"quality_score: {score:.3f} < {tolerances['quality_score_min']}") + if "quality_score_max" in tolerances and score > float(tolerances["quality_score_max"]): + failures.append(f"quality_score: {score:.3f} > {tolerances['quality_score_max']}") + + for label, count, exact_key, range_key in ( + ("element_count", len(parsed.elements), "element_count", "element_count_range"), + ("table_count", len(parsed.tables), "table_count", "table_count_range"), + ("figure_count", len(parsed.figures), "figure_count", "figure_count_range"), + ): + message = _check_int_or_range(count, tolerances.get(exact_key), tolerances.get(range_key), label) + if message: + failures.append(message) + + chunk_count = len(parsed.chunks) + if "chunk_count_min" in tolerances and chunk_count < int(tolerances["chunk_count_min"]): + failures.append(f"chunk_count: {chunk_count} < {tolerances['chunk_count_min']}") + if "chunk_count_max" in tolerances and chunk_count > int(tolerances["chunk_count_max"]): + failures.append(f"chunk_count: {chunk_count} > {tolerances['chunk_count_max']}") + + if "blocking_failures" in tolerances: + actual = parsed.quality_report.has_blocking_failures + expected = bool(tolerances["blocking_failures"]) + if actual != expected: + failures.append(f"blocking_failures: expected {expected}, got {actual}") + + md = parsed.to_markdown() + for needle in tolerances.get("must_contain_markdown", []) or []: + if str(needle) not in md: + failures.append(f"must_contain_markdown: {needle!r} not found") + for needle in tolerances.get("must_not_contain_markdown", []) or []: + if str(needle) in md: + failures.append(f"must_not_contain_markdown: {needle!r} present") + + metrics = parsed.quality_report.metrics + for key in tolerances.get("must_contain_quality_metrics", []) or []: + if key not in metrics: + failures.append(f"must_contain_quality_metrics: {key!r} missing") + + if "parser_disagreement_rate_max" in tolerances: + rate = float(metrics.get("parser_disagreement_rate", 0.0)) + if rate > float(tolerances["parser_disagreement_rate_max"]): + failures.append( + f"parser_disagreement_rate: {rate:.3f} > {tolerances['parser_disagreement_rate_max']}" + ) + if "repair_resolution_rate_min" in tolerances: + rate = float(metrics.get("repair_resolution_rate", 1.0)) + if rate < float(tolerances["repair_resolution_rate_min"]): + failures.append( + f"repair_resolution_rate: {rate:.3f} < {tolerances['repair_resolution_rate_min']}" + ) + + return failures + + +def _perf_enforcement_enabled(performance: dict[str, Any]) -> bool: + if performance.get("always_enforce"): + return True + return os.environ.get("ZSGDP_REGRESSION_PERF", "").strip().lower() in {"1", "true", "yes"} + + +def _measure_parse(input_path: Path, *, config_path: Path | None, selected_parsers, repeats: int) -> tuple[Any, list[float]]: + """Parse the input N times, returning (last_parsed, list_of_elapsed_seconds). + + Uses a fresh temp output directory for each run so disk caching effects + are roughly equal across runs. The last parsed document is returned for + tolerance evaluation; per-run elapsed times feed the perf assertion. + """ + + elapsed: list[float] = [] + parsed = None + for _ in range(max(1, repeats)): + with tempfile.TemporaryDirectory() as tmp: + started = time.perf_counter() + parsed = parse_document( + input_path, + Path(tmp) / "out", + config_path=config_path if config_path else None, + selected_parsers=selected_parsers, + ) + elapsed.append(time.perf_counter() - started) + return parsed, elapsed + + +def _evaluate_performance(parsed, performance: dict[str, Any], elapsed_seconds: list[float]) -> list[str]: + failures: list[str] = [] + if not elapsed_seconds: + return failures + + median_elapsed = statistics.median(elapsed_seconds) + page_count = max(len(parsed.pages), 1) + median_pages_per_second = page_count / median_elapsed if median_elapsed > 0 else float("inf") + + max_elapsed = performance.get("max_elapsed_seconds") + if max_elapsed is not None and median_elapsed > float(max_elapsed): + failures.append( + f"performance.max_elapsed_seconds: median {median_elapsed:.2f}s > {max_elapsed}s " + f"(runs={len(elapsed_seconds)})" + ) + + min_pps = performance.get("min_pages_per_second") + if min_pps is not None and median_pages_per_second < float(min_pps): + failures.append( + f"performance.min_pages_per_second: median {median_pages_per_second:.2f} < {min_pps} " + f"(runs={len(elapsed_seconds)})" + ) + + return failures + + +class RegressionFixturesTest(unittest.TestCase): + def test_regression_fixtures_match_snapshots(self): + fixtures = _discover_fixtures() + if not fixtures: + self.skipTest("No regression fixtures present.") + + all_failures: list[str] = [] + for name, input_path, expected_path in fixtures: + with self.subTest(fixture=name): + expected = json.loads(expected_path.read_text(encoding="utf-8")) + tolerances = expected.get("tolerances") or {} + performance = expected.get("performance") or {} + config_rel = expected.get("config") + config_path = Path(config_rel) if config_rel else None + if config_path and not config_path.is_absolute(): + config_path = Path(__file__).resolve().parents[2] / config_path + selected_parsers = expected.get("selected_parsers") + + perf_enabled = bool(performance) and _perf_enforcement_enabled(performance) + repeats = int(performance.get("repeats", 2)) if perf_enabled else 1 + + parsed, elapsed = _measure_parse( + input_path, + config_path=config_path, + selected_parsers=selected_parsers, + repeats=repeats, + ) + + failures = _evaluate(parsed, tolerances) + if perf_enabled: + failures.extend(_evaluate_performance(parsed, performance, elapsed)) + if failures: + all_failures.append(f"[{name}] " + "; ".join(failures)) + + if all_failures: + self.fail("\n".join(all_failures)) + + +class PerformanceEvaluatorTests(unittest.TestCase): + """Unit tests for the perf-evaluation helpers, separate from fixture discovery.""" + + def test_max_elapsed_floor_fires_when_too_slow(self): + from types import SimpleNamespace + + parsed = SimpleNamespace(pages=[{"page_num": 1}]) + failures = _evaluate_performance(parsed, {"max_elapsed_seconds": 0.1}, [0.5, 0.5]) + self.assertEqual(len(failures), 1) + self.assertIn("max_elapsed_seconds", failures[0]) + + def test_min_pages_per_second_fires_when_too_slow(self): + from types import SimpleNamespace + + parsed = SimpleNamespace(pages=[{"page_num": 1}]) + # 1 page in 10s => 0.1 pps, floor 1.0 => fail. + failures = _evaluate_performance(parsed, {"min_pages_per_second": 1.0}, [10.0, 10.0]) + self.assertEqual(len(failures), 1) + self.assertIn("min_pages_per_second", failures[0]) + + def test_passing_floors_yield_no_failures(self): + from types import SimpleNamespace + + parsed = SimpleNamespace(pages=[{"page_num": 1}, {"page_num": 2}]) + # 2 pages in 0.5s => 4 pps; floor 1.0 pps and max 2s. + failures = _evaluate_performance( + parsed, + {"max_elapsed_seconds": 2.0, "min_pages_per_second": 1.0}, + [0.5, 0.5, 0.5], + ) + self.assertEqual(failures, []) + + def test_median_strips_cold_outlier(self): + from types import SimpleNamespace + + parsed = SimpleNamespace(pages=[{"page_num": 1}]) + # First run cold (5s), next two warm (0.1s). Median = 0.1s; floor 1s passes. + failures = _evaluate_performance(parsed, {"max_elapsed_seconds": 1.0}, [5.0, 0.1, 0.1]) + self.assertEqual(failures, []) + + def test_perf_enforcement_gating(self): + with unittest.mock.patch.dict("os.environ", {"ZSGDP_REGRESSION_PERF": "0"}, clear=False): + self.assertFalse(_perf_enforcement_enabled({"max_elapsed_seconds": 1.0})) + self.assertTrue(_perf_enforcement_enabled({"always_enforce": True})) + + with unittest.mock.patch.dict("os.environ", {"ZSGDP_REGRESSION_PERF": "1"}, clear=False): + self.assertTrue(_perf_enforcement_enabled({"max_elapsed_seconds": 1.0})) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ablation_runner.py b/tests/test_ablation_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..4f5e73a36790ae5555570f8e13cdd7d18bec541d --- /dev/null +++ b/tests/test_ablation_runner.py @@ -0,0 +1,133 @@ +"""Tests for parser-contribution metrics and the ablation runner.""" + +from __future__ import annotations + +import json +import tempfile +import unittest +from pathlib import Path + +from zsgdp.benchmarks.ablation_runner import ABLATION_METRIC_KEYS, run_parser_ablations +from zsgdp.benchmarks.parser_quality import run_parser_benchmark + + +class TestParserContribution(unittest.TestCase): + def test_contribution_counts_appear_in_summary(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text("# Doc\n\nA paragraph.\n", encoding="utf-8") + + summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder") + + doc = summary["documents"][0] + self.assertIn("parser_contribution_counts", doc) + self.assertIn("parser_contribution_fractions", doc) + self.assertGreater(sum(doc["parser_contribution_counts"].values()), 0) + # The sum of fractions should be ~1.0 across parsers. + total_fraction = sum(doc["parser_contribution_fractions"].values()) + self.assertAlmostEqual(total_fraction, 1.0, places=6) + + top_summary = summary["parser_contribution_summary"] + self.assertGreater(top_summary["total"], 0) + self.assertEqual(set(top_summary["counts"]), set(top_summary["fractions"])) + + def test_text_parser_dominates_markdown_doc(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text("# Doc\n\nPara one.\n\nPara two.\n", encoding="utf-8") + + summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder") + + top_counts = summary["parser_contribution_summary"]["counts"] + self.assertIn("text", top_counts) + text_count = top_counts["text"] + other_count = sum(value for parser, value in top_counts.items() if parser != "text") + self.assertGreaterEqual(text_count, other_count) + + +class TestRunParserAblations(unittest.TestCase): + def test_two_arms_plus_merged(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text("# Doc\n\nPara one.\n\nPara two.\n", encoding="utf-8") + out = tmp / "out" + + comparison = run_parser_ablations( + src, + out, + parsers=["text", "pymupdf"], + dataset_name="custom_folder", + ) + + self.assertEqual(comparison["arm_count"], 3) + arms = sorted(row["arm"] for row in comparison["rows"]) + self.assertEqual(arms, ["merged", "pymupdf", "text"]) + self.assertTrue((out / "arm_text").exists()) + self.assertTrue((out / "arm_pymupdf").exists()) + self.assertTrue((out / "arm_merged").exists()) + self.assertTrue((out / "ablation_comparison.csv").exists()) + self.assertTrue((out / "ablation_summary.json").exists()) + + # Each arm record carries the canonical metric keys (subset of those present). + for row in comparison["rows"]: + self.assertIn("mean_quality_score", row) + + def test_no_merged_when_disabled(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text("# Doc\n\nPara.\n", encoding="utf-8") + + comparison = run_parser_ablations( + src, + tmp / "out", + parsers=["text", "pymupdf"], + dataset_name="custom_folder", + include_merged=False, + ) + self.assertEqual(comparison["arm_count"], 2) + self.assertNotIn("merged", {row["arm"] for row in comparison["rows"]}) + + def test_single_parser_ablation_skips_merged_arm(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text("# Doc\n\nPara.\n", encoding="utf-8") + + comparison = run_parser_ablations( + src, + tmp / "out", + parsers=["text"], + dataset_name="custom_folder", + ) + # Single parser + include_merged defaults true, but len(parsers) == 1 + # so merged would be redundant and is skipped. + self.assertEqual(comparison["arm_count"], 1) + self.assertEqual(comparison["rows"][0]["arm"], "text") + + def test_empty_parsers_raises(self): + with self.assertRaises(ValueError): + run_parser_ablations(".", "./out", parsers=[]) + + def test_metric_keys_constant_matches_summary_shape(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text("# Doc\n\nPara.\n", encoding="utf-8") + + summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder") + for key in ABLATION_METRIC_KEYS: + self.assertIn(key, summary, f"benchmark summary missing key {key}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 0000000000000000000000000000000000000000..4e152b96e90ff58ca3ab7c05e92394a16ea6d519 --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,141 @@ +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +try: + import app as space_app +except RuntimeError as exc: + space_app = None + APP_IMPORT_ERROR = str(exc) +else: + APP_IMPORT_ERROR = "" + + +class _UploadedFile: + def __init__(self, name: str): + self.name = name + + +class AppTests(unittest.TestCase): + def test_parse_uploaded_document_returns_artifact_validation(self): + if space_app is None: + self.skipTest(APP_IMPORT_ERROR) + + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "sample.md" + input_path.write_text("# Report\n\nHello from the Space UI.\n", encoding="utf-8") + + outputs = space_app.parse_uploaded_document(_UploadedFile(str(input_path)), "Default lightweight") + + self.assertEqual(len(outputs), 11) + summary = outputs[1] + artifact_validation = outputs[8] + archive_path = outputs[9] + individual_files = outputs[10] + self.assertTrue(summary["artifact_manifest_valid"]) + self.assertTrue(artifact_validation["valid"]) + self.assertTrue(Path(archive_path).exists()) + # Per-artifact downloads. + self.assertIsInstance(individual_files, list) + self.assertGreater(len(individual_files), 0) + names = [Path(p).name for p in individual_files] + # Core artifacts every parse should produce. + for required in ("parsed_document.json", "document.md", "chunks.jsonl", "artifact_manifest.json"): + self.assertIn(required, names) + # Each path actually exists on disk so Gradio can serve it. + for path in individual_files: + self.assertTrue(Path(path).exists(), f"missing: {path}") + # The archive zip is a separate artifact and must NOT appear in the + # per-artifact list (zip is the bundled-everything view). + self.assertNotIn(Path(archive_path).name, names) + # Summary records the per-artifact count. + self.assertEqual(summary["individual_artifact_count"], len(individual_files)) + + +class UploadGuardTests(unittest.TestCase): + def test_oversized_upload_rejected_with_clear_message(self): + if space_app is None: + self.skipTest(APP_IMPORT_ERROR) + + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "huge.md" + input_path.write_text("# Big\n\n" + "x" * 4096, encoding="utf-8") + + with patch.object(space_app, "MAX_UPLOAD_BYTES", 1024): + outputs = space_app.parse_uploaded_document( + _UploadedFile(str(input_path)), "Default lightweight" + ) + + summary = outputs[1] + self.assertTrue(summary.get("rejected")) + self.assertIn("MB", summary["error"]) + + def test_high_page_count_rejected(self): + if space_app is None: + self.skipTest(APP_IMPORT_ERROR) + + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "doc.md" + input_path.write_text("# Doc\n\nSomething small.\n", encoding="utf-8") + + class _FakeProfile: + page_count = 1000 + + with patch.object(space_app, "MAX_PAGE_COUNT", 50), patch.object( + space_app, "profile_document", return_value=_FakeProfile() + ): + outputs = space_app.parse_uploaded_document( + _UploadedFile(str(input_path)), "Default lightweight" + ) + + summary = outputs[1] + self.assertTrue(summary.get("rejected")) + self.assertIn("pages", summary["error"]) + + def test_missing_upload_path_rejected(self): + if space_app is None: + self.skipTest(APP_IMPORT_ERROR) + + outputs = space_app.parse_uploaded_document( + _UploadedFile("/tmp/zsgdp-does-not-exist.md"), "Default lightweight" + ) + summary = outputs[1] + self.assertTrue(summary.get("rejected")) + self.assertIn("missing", summary["error"].lower()) + + def test_error_paths_return_full_tuple_width(self): + # Drift guard: every return path (success + error) must yield 11 outputs + # so the Gradio click handler doesn't error on shape mismatch. + if space_app is None: + self.skipTest(APP_IMPORT_ERROR) + + # No upload at all. + outputs = space_app.parse_uploaded_document(None, "Default lightweight") + self.assertEqual(len(outputs), 11) + self.assertEqual(outputs[10], []) + + # Missing-file rejection. + outputs = space_app.parse_uploaded_document( + _UploadedFile("/tmp/zsgdp-does-not-exist-xyz.md"), "Default lightweight" + ) + self.assertEqual(len(outputs), 11) + self.assertEqual(outputs[10], []) + + def test_normal_upload_passes_guards(self): + if space_app is None: + self.skipTest(APP_IMPORT_ERROR) + + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "ok.md" + input_path.write_text("# OK\n\nA normal document.\n", encoding="utf-8") + outputs = space_app.parse_uploaded_document( + _UploadedFile(str(input_path)), "Default lightweight" + ) + + summary = outputs[1] + self.assertNotIn("rejected", summary) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e522fd9d5a8309b1c47b546ca15b3d9e2df395 --- /dev/null +++ b/tests/test_artifacts.py @@ -0,0 +1,82 @@ +import json +import tempfile +import unittest +from pathlib import Path + +from zsgdp.artifacts import MANIFEST_SCHEMA_VERSION, validate_artifact_manifest +from zsgdp.cli import main +from zsgdp.pipeline import parse_document +from zsgdp.schema import SCHEMA_VERSION + + +class ArtifactManifestTests(unittest.TestCase): + def test_parse_writes_valid_artifact_manifest(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "sample.md" + output_dir = tmp_path / "out" + input_path.write_text("# Report\n\nHello world.\n", encoding="utf-8") + + parsed = parse_document(input_path, output_dir) + manifest = json.loads((output_dir / "artifact_manifest.json").read_text(encoding="utf-8")) + validation = validate_artifact_manifest(output_dir) + + self.assertEqual(manifest["doc_id"], parsed.doc_id) + self.assertEqual(manifest["counts"]["chunks"], len(parsed.chunks)) + self.assertTrue(any(record["path"] == "parsed_document.json" for record in manifest["files"])) + self.assertTrue(validation["valid"]) + self.assertEqual(validation["checked_count"], manifest["artifact_count"]) + + def test_manifest_records_schema_versions(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "sample.md" + output_dir = tmp_path / "out" + input_path.write_text("# Report\n\nHello.\n", encoding="utf-8") + + parsed = parse_document(input_path, output_dir) + manifest = json.loads((output_dir / "artifact_manifest.json").read_text(encoding="utf-8")) + + # Manifest format version is its own integer; parsed-document + # schema version is a string echoed from the dataclass. + self.assertEqual(manifest["schema_version"], MANIFEST_SCHEMA_VERSION) + self.assertEqual(manifest["parsed_document_schema_version"], SCHEMA_VERSION) + self.assertEqual(parsed.schema_version, SCHEMA_VERSION) + + # Validation echoes both versions so callers can gate on them. + validation = validate_artifact_manifest(output_dir) + self.assertEqual(validation["manifest_schema_version"], MANIFEST_SCHEMA_VERSION) + self.assertEqual(validation["parsed_document_schema_version"], SCHEMA_VERSION) + + def test_validate_artifact_manifest_detects_checksum_mismatch(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "sample.md" + output_dir = tmp_path / "out" + input_path.write_text("# Report\n\nHello world.\n", encoding="utf-8") + parse_document(input_path, output_dir) + + (output_dir / "document.md").write_text("tampered\n", encoding="utf-8") + validation = validate_artifact_manifest(output_dir) + + self.assertFalse(validation["valid"]) + self.assertTrue(any("SHA-256 mismatch: document.md" == error for error in validation["errors"])) + + def test_validate_artifacts_cli_writes_report(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "sample.md" + output_dir = tmp_path / "out" + report_path = tmp_path / "validation.json" + input_path.write_text("# Report\n\nHello world.\n", encoding="utf-8") + parse_document(input_path, output_dir) + + code = main(["validate-artifacts", "--parsed", str(output_dir), "--output", str(report_path)]) + + self.assertEqual(code, 0) + self.assertTrue(report_path.exists()) + self.assertTrue(json.loads(report_path.read_text(encoding="utf-8"))["valid"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9601a4eaadebc76409a9ae7de2457605b4023de5 --- /dev/null +++ b/tests/test_benchmark.py @@ -0,0 +1,55 @@ +import tempfile +import unittest +from pathlib import Path + +from zsgdp.benchmarks.parser_quality import run_parser_benchmark +from zsgdp.cli import main + + +class BenchmarkTests(unittest.TestCase): + def test_run_parser_benchmark_writes_results(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + docs = tmp_path / "docs" + out = tmp_path / "bench" + docs.mkdir() + (docs / "one.md").write_text("# One\n\nHello world", encoding="utf-8") + + summary = run_parser_benchmark(docs, out) + + self.assertEqual(summary["document_count"], 1) + self.assertIn("fixed_token_baseline", summary["documents"][0]["chunk_strategy_counts"]) + self.assertTrue(summary["chunk_strategy_leaderboard"]) + self.assertIn("structure_quality", summary) + self.assertIn("chunking_quality", summary) + self.assertIn("throughput", summary) + self.assertIn("ablation_plan", summary) + self.assertTrue((out / "results.json").exists()) + self.assertTrue((out / "leaderboard.csv").exists()) + self.assertTrue((out / "parser_runs.csv").exists()) + self.assertTrue((out / "chunk_runs.csv").exists()) + self.assertTrue((out / "structure_runs.csv").exists()) + self.assertTrue((out / "chunk_quality.csv").exists()) + self.assertTrue((out / "throughput_runs.csv").exists()) + self.assertTrue((out / "ablations.json").exists()) + + def test_benchmark_cli(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + docs = tmp_path / "docs" + out = tmp_path / "bench" + docs.mkdir() + (docs / "one.md").write_text("# One\n\nHello world", encoding="utf-8") + + code = main(["benchmark", "--input", str(docs), "--output", str(out), "--parsers", "text"]) + + self.assertEqual(code, 0) + self.assertTrue((out / "leaderboard.csv").exists()) + self.assertTrue((out / "chunk_runs.csv").exists()) + self.assertTrue((out / "structure_runs.csv").exists()) + self.assertTrue((out / "chunk_quality.csv").exists()) + self.assertTrue((out / "throughput_runs.csv").exists()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_chunking.py b/tests/test_chunking.py new file mode 100644 index 0000000000000000000000000000000000000000..65718db190bed7469ae42b5abb27d88728a3c2a4 --- /dev/null +++ b/tests/test_chunking.py @@ -0,0 +1,286 @@ +import unittest + +from zsgdp.chunking import build_agentic_chunks +from zsgdp.config import load_config +from zsgdp.schema import DocumentProfile, Element, FigureObject, PageProfile, ParsedDocument, QualityReport, TableObject +from zsgdp.verify import verify_chunks + + +class ChunkingTests(unittest.TestCase): + def test_agentic_chunking_builds_parent_child_chunks(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.md", + file_type="markdown", + page_count=1, + extension=".md", + pages=[PageProfile(page_num=1, digital_text_chars=120, digital_text_quality=1.0)], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.md", + file_type="markdown", + quality_report=QualityReport(score=0.95), + ) + parsed.elements.extend( + [ + Element("e1", "d1", 1, "title", markdown="# Report", reading_order=1, source_parser="text"), + Element("e2", "d1", 1, "paragraph", text=" ".join(["alpha"] * 80), reading_order=2, source_parser="text"), + ] + ) + + chunks = build_agentic_chunks(parsed, profile, load_config()) + + self.assertTrue(any(chunk.content_type == "parent" for chunk in chunks)) + self.assertTrue(any(chunk.parent_chunk_id for chunk in chunks)) + self.assertEqual(parsed.provenance["chunking"]["plan"]["target_tokens"], 512) + + def test_chunk_readiness_adds_metrics(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.md", + file_type="markdown", + page_count=1, + extension=".md", + pages=[PageProfile(page_num=1, digital_text_chars=120, digital_text_quality=1.0)], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.md", + file_type="markdown", + quality_report=QualityReport(score=0.95), + ) + parsed.elements.append( + Element("e1", "d1", 1, "paragraph", text=" ".join(["alpha"] * 80), reading_order=1, source_parser="text") + ) + parsed.chunks = build_agentic_chunks(parsed, profile, load_config()) + + report = verify_chunks(parsed, load_config()) + + self.assertEqual(report.metrics["chunk_count"], len(parsed.chunks)) + self.assertIn("fixed_token_baseline", report.metrics["chunk_strategy_counts"]) + self.assertIn("recursive_structure", report.metrics["chunk_strategy_counts"]) + + def test_fixed_token_baseline_chunks_are_emitted_with_provenance(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.md", + file_type="markdown", + page_count=2, + extension=".md", + pages=[ + PageProfile(page_num=1, digital_text_chars=120, digital_text_quality=1.0), + PageProfile(page_num=2, digital_text_chars=120, digital_text_quality=1.0), + ], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.md", + file_type="markdown", + quality_report=QualityReport(score=0.95), + ) + parsed.elements.extend( + [ + Element("e1", "d1", 1, "paragraph", text=" ".join(["alpha"] * 18), reading_order=1, source_parser="text"), + Element("e2", "d1", 2, "paragraph", text=" ".join(["beta"] * 18), reading_order=1, source_parser="text"), + ] + ) + config = load_config(overrides={"chunking": {"target_tokens": 10, "overlap_ratio": 0.2}}) + + chunks = build_agentic_chunks(parsed, profile, config) + baseline_chunks = [chunk for chunk in chunks if chunk.strategy == "fixed_token_baseline"] + + self.assertGreaterEqual(len(baseline_chunks), 4) + self.assertEqual(baseline_chunks[0].element_ids, ["e1"]) + self.assertEqual(baseline_chunks[-1].page_end, 2) + self.assertEqual(parsed.provenance["chunking"]["fixed_token_baseline_count"], len(baseline_chunks)) + + def test_figure_without_caption_still_gets_visual_chunk(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=20, digital_text_quality=1.0)], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + quality_report=QualityReport(score=0.90), + ) + parsed.elements.append(Element("e1", "d1", 1, "paragraph", text="hello world", reading_order=1, source_parser="pymupdf")) + parsed.figures.append( + FigureObject( + figure_id="f1", + page_num=1, + image_path="/tmp/figure.png", + confidence=0.5, + source_parser="pymupdf", + ) + ) + + parsed.chunks = build_agentic_chunks(parsed, profile, load_config()) + report = verify_chunks(parsed, load_config()) + + self.assertTrue(any(chunk.figure_ids == ["f1"] for chunk in parsed.chunks)) + self.assertEqual(report.metrics["figure_chunk_coverage"], 1.0) + + def test_table_chunk_keeps_multimodal_metadata(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=20, digital_text_quality=1.0)], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + quality_report=QualityReport(score=0.90), + ) + parsed.elements.append(Element("e1", "d1", 1, "paragraph", text="hello world", reading_order=1, source_parser="pymupdf")) + parsed.tables.append( + TableObject( + table_id="t1", + page_nums=[1], + bbox=[(1.0, 2.0, 3.0, 4.0)], + markdown="| A | B |\n| --- | --- |\n| 1 | 2 |", + natural_language_rendering="Table with columns A, B. Rows: 1: B=2.", + confidence=0.82, + source_parser="pymupdf", + provenance={"crop_path": "/tmp/table.png", "source_parsers": ["pymupdf", "docling"]}, + ) + ) + + parsed.chunks = build_agentic_chunks(parsed, profile, load_config()) + table_chunk = next(chunk for chunk in parsed.chunks if chunk.strategy == "table_object") + + self.assertEqual(table_chunk.text, "Table with columns A, B. Rows: 1: B=2.") + self.assertEqual(table_chunk.metadata["markdown"], "| A | B |\n| --- | --- |\n| 1 | 2 |") + self.assertEqual(table_chunk.metadata["bbox"], [(1.0, 2.0, 3.0, 4.0)]) + self.assertEqual(table_chunk.metadata["crop_path"], "/tmp/table.png") + self.assertEqual(table_chunk.metadata["source_parsers"], ["pymupdf", "docling"]) + + def test_vision_guided_chunking_exports_visual_regions(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=20, digital_text_quality=1.0)], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + quality_report=QualityReport(score=0.90), + ) + parsed.elements.append(Element("e1", "d1", 1, "paragraph", text="hello world", reading_order=1, source_parser="pymupdf")) + parsed.tables.append(TableObject(table_id="t1", page_nums=[1], bbox=[(1.0, 2.0, 3.0, 4.0)], markdown="| A | B |\n| --- | --- |\n| 1 | 2 |")) + parsed.figures.append(FigureObject(figure_id="f1", page_num=1, bbox=(5.0, 6.0, 7.0, 8.0), source_parser="pymupdf")) + config = load_config(overrides={"chunking": {"vision_guided": True}}) + + parsed.chunks = build_agentic_chunks(parsed, profile, config) + + visual_chunks = [chunk for chunk in parsed.chunks if chunk.content_type in {"table", "figure"}] + self.assertTrue(all(chunk.requires_visual_context for chunk in visual_chunks)) + self.assertEqual(len(parsed.provenance["chunking"]["vision_regions"]), 2) + self.assertEqual(parsed.provenance["chunking"]["vision_regions"][0]["region_id"], "t1") + + def test_advanced_chunking_flags_emit_strategy_chunks(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=2, + extension=".pdf", + pages=[ + PageProfile(page_num=1, digital_text_chars=200, digital_text_quality=1.0), + PageProfile(page_num=2, digital_text_chars=200, digital_text_quality=1.0), + ], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + quality_report=QualityReport(score=0.92), + ) + parsed.elements.extend( + [ + Element("e1", "d1", 1, "heading", markdown="## Revenue", reading_order=1, source_parser="pymupdf"), + Element( + "e2", + "d1", + 1, + "paragraph", + text="Revenue increased by 12 percent in Q1. Gross margin improved due to pricing.", + reading_order=2, + source_parser="pymupdf", + ), + Element("e3", "d1", 2, "heading", markdown="## Safety", reading_order=1, source_parser="pymupdf"), + Element( + "e4", + "d1", + 2, + "paragraph", + text="Safety inspections found three unresolved risks. Corrective actions are due in June.", + reading_order=2, + source_parser="pymupdf", + ), + ] + ) + parsed.tables.append( + TableObject( + table_id="t1", + page_nums=[1], + markdown="| Metric | Value |\n| --- | --- |\n| Revenue | 12% |", + natural_language_rendering="Table t1 reports revenue growth of 12 percent.", + source_parser="pymupdf", + ) + ) + parsed.figures.append( + FigureObject( + figure_id="f1", + page_num=2, + caption="Risk trend chart shows open safety findings.", + source_parser="pymupdf", + ) + ) + config = load_config( + overrides={ + "chunking": { + "contextual_retrieval": True, + "semantic_chunking": True, + "late_chunking": True, + "vision_guided": True, + "agentic_proposition_chunking": True, + } + } + ) + + parsed.chunks = build_agentic_chunks(parsed, profile, config) + strategies = {chunk.strategy for chunk in parsed.chunks} + + self.assertIn("semantic", strategies) + self.assertIn("late", strategies) + self.assertIn("contextual_retrieval", strategies) + self.assertIn("vision_guided", strategies) + self.assertIn("agentic_proposition", strategies) + self.assertGreater(parsed.provenance["chunking"]["semantic_chunk_count"], 0) + self.assertGreater(parsed.provenance["chunking"]["late_chunk_count"], 0) + self.assertGreater(parsed.provenance["chunking"]["contextual_retrieval_chunk_count"], 0) + semantic_chunk = next(chunk for chunk in parsed.chunks if chunk.strategy == "semantic") + self.assertEqual(semantic_chunk.metadata["execution_mode"], "lexical_similarity_proxy") + contextual_chunk = next(chunk for chunk in parsed.chunks if chunk.strategy == "contextual_retrieval") + self.assertIn("source_chunk_id", contextual_chunk.metadata) + late_chunk = next(chunk for chunk in parsed.chunks if chunk.strategy == "late") + self.assertTrue(late_chunk.metadata["requires_token_level_embeddings"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cli_help.py b/tests/test_cli_help.py new file mode 100644 index 0000000000000000000000000000000000000000..18cec6a3f87b021c5793c445c8e48480ddad55cf --- /dev/null +++ b/tests/test_cli_help.py @@ -0,0 +1,91 @@ +"""Tests guarding CLI help text — examples must render and stay clean.""" + +from __future__ import annotations + +import io +import unittest +from contextlib import redirect_stdout + +from zsgdp.cli import _epilog, main + + +def _capture_help(argv: list[str]) -> str: + """Run `zsgdp --help` and return captured stdout. SystemExit is normal.""" + + buffer = io.StringIO() + with redirect_stdout(buffer): + try: + main(argv + ["--help"]) + except SystemExit: + pass + return buffer.getvalue() + + +class EpilogFormatterTests(unittest.TestCase): + def test_epilog_dedents_indented_source_string(self): + rendered = _epilog( + """ + zsgdp parse --input ./a --output ./b + zsgdp parse --input ./c --output ./d + """ + ) + # No double-indentation; first non-blank line begins with two spaces only. + lines = rendered.splitlines() + self.assertEqual(lines[0], "Examples:") + self.assertTrue(lines[1].startswith(" zsgdp parse")) + # No source-indent leak. + self.assertNotIn(" zsgdp", rendered) + + def test_epilog_preserves_blank_lines_as_separators(self): + rendered = _epilog( + """ + line one + + line two + """ + ) + self.assertIn("\n\n", rendered) + + +class SubcommandHelpTests(unittest.TestCase): + def test_top_level_help_lists_examples_section(self): + text = _capture_help([]) + self.assertIn("Examples:", text) + self.assertIn("zsgdp parse", text) + self.assertIn("docs/space_smoke.md", text) + + def test_parse_help_has_examples(self): + text = _capture_help(["parse"]) + self.assertIn("Examples:", text) + self.assertIn("zsgdp parse --input", text) + self.assertIn("--config configs/docling.yaml", text) + + def test_benchmark_help_covers_three_dataset_modes(self): + text = _capture_help(["benchmark"]) + self.assertIn("Examples:", text) + self.assertIn("--dataset omnidocbench", text) + self.assertIn("--dataset doclaynet", text) + + def test_benchmark_ablate_shows_merged_arm_pattern(self): + text = _capture_help(["benchmark-ablate"]) + self.assertIn("--parser docling --parser pymupdf", text) + self.assertIn("--no-merged", text) + + def test_run_gpu_tasks_documents_dry_run_vs_execute(self): + text = _capture_help(["run-gpu-tasks"]) + self.assertIn("Dry-run", text) + self.assertIn("--execute", text) + + def test_combine_benchmarks_shows_label_pairing(self): + text = _capture_help(["combine-benchmarks"]) + self.assertIn("--label omnidocbench", text) + self.assertIn("--label doclaynet", text) + + def test_preflight_help_documents_skip_flags(self): + text = _capture_help(["preflight"]) + self.assertIn("--benchmark", text) + self.assertIn("--skip-unit", text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_conflict_detection.py b/tests/test_conflict_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..91a7e0ceb7d5709b8f009226e0b6ec2c1d31368e --- /dev/null +++ b/tests/test_conflict_detection.py @@ -0,0 +1,89 @@ +import tempfile +import unittest +from pathlib import Path + +from zsgdp.export import export_parsed_document +from zsgdp.merge.conflict_detection import build_candidate_conflict_report, detect_candidate_conflicts +from zsgdp.merge.merge_candidates import merge_candidates +from zsgdp.schema import DocumentProfile, Element, PageProfile, ParseCandidate, TableObject + + +class ConflictDetectionTests(unittest.TestCase): + def test_conflict_report_flags_reading_order_and_table_structure(self): + candidates = [_candidate("docling", ["Alpha", "Beta", "Gamma"], 3), _candidate("pymupdf", ["Gamma", "Beta", "Alpha"], 2)] + + report = build_candidate_conflict_report(candidates) + issues = detect_candidate_conflicts(candidates) + + conflict_types = {conflict["type"] for conflict in report["conflicts"]} + self.assertIn("reading_order_disagreement", conflict_types) + self.assertIn("table_structure_disagreement", conflict_types) + self.assertTrue(issues) + self.assertTrue(all(issue.issue_type == "parser_disagreement" for issue in issues)) + + def test_merge_stores_and_exports_conflict_report(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=30)], + ) + parsed = merge_candidates( + [_candidate("docling", ["Alpha", "Beta", "Gamma"], 3), _candidate("pymupdf", ["Gamma", "Beta", "Alpha"], 2)], + profile, + ) + + with tempfile.TemporaryDirectory() as tmp: + output_dir = Path(tmp) / "out" + export_parsed_document(parsed, output_dir) + + self.assertTrue((output_dir / "conflict_report.json").exists()) + + self.assertIn("conflict_report", parsed.provenance) + self.assertGreater(parsed.provenance["conflict_report"]["conflict_count"], 0) + + +def _candidate(parser_name: str, ordered_text: list[str], table_columns: int) -> ParseCandidate: + elements = [ + Element( + element_id=f"{parser_name}_e{index}", + doc_id="d1", + page_num=1, + type="paragraph", + text=text, + reading_order=index, + confidence=0.8, + source_parser=parser_name, + ) + for index, text in enumerate(ordered_text, start=1) + ] + return ParseCandidate( + parser_name=parser_name, + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + pages=[{"page_num": 1, "source_parser": parser_name}], + elements=elements, + tables=[ + TableObject( + table_id=f"{parser_name}_t1", + page_nums=[1], + markdown=_table_markdown(table_columns), + confidence=0.8, + source_parser=parser_name, + ) + ], + confidence=0.8, + ) + + +def _table_markdown(columns: int) -> str: + if columns == 3: + return "| Region | Q1 | Q2 |\n| --- | --- | --- |\n| NA | 10 | 12 |" + return "| Region | Q1 |\n| --- | --- |\n| NA | 10 |" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cross_dataset.py b/tests/test_cross_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..40ad561f714db74720f85a2b2242aa81fd1ac634 --- /dev/null +++ b/tests/test_cross_dataset.py @@ -0,0 +1,123 @@ +"""Tests for cross-dataset benchmark comparison.""" + +from __future__ import annotations + +import json +import tempfile +import unittest +from pathlib import Path + +from zsgdp.benchmarks.cross_dataset import ( + combine_benchmark_summaries, + write_cross_dataset_outputs, +) + + +def _summary(dataset_name: str, *, layout_f1: float, leaderboard: list[dict] | None = None) -> dict: + return { + "dataset_name": dataset_name, + "dataset_root": f"/tmp/{dataset_name}", + "document_count": 5, + "mean_quality_score": 0.9, + "mean_layout_f1": layout_f1, + "mean_retrieval_recall_at_5": 0.7, + "mean_table_structure_score": 0.6, + "mean_formula_cer": 0.2, + "per_parser_gt_leaderboard": leaderboard or [], + } + + +class TestCombineBenchmarkSummaries(unittest.TestCase): + def test_two_runs_produce_two_rows(self): + runs = [ + ("docs_a", _summary("docs_a", layout_f1=0.5)), + ("docs_b", _summary("docs_b", layout_f1=0.8)), + ] + comparison = combine_benchmark_summaries(runs) + self.assertEqual(comparison["run_count"], 2) + self.assertEqual(comparison["labels"], ["docs_a", "docs_b"]) + self.assertEqual([row["label"] for row in comparison["dataset_summary"]], ["docs_a", "docs_b"]) + layouts = {row["label"]: row["mean_layout_f1"] for row in comparison["dataset_summary"]} + self.assertEqual(layouts, {"docs_a": 0.5, "docs_b": 0.8}) + + def test_parser_matrix_aligns_parsers_across_runs(self): + leaderboard_a = [ + {"parser": "docling", "mean_layout_class_aware_f1": 0.9, "document_count": 3}, + {"parser": "pymupdf", "mean_layout_class_aware_f1": 0.4, "document_count": 3}, + ] + leaderboard_b = [ + {"parser": "docling", "mean_layout_class_aware_f1": 0.7, "document_count": 5}, + # marker only appears in run B. + {"parser": "marker", "mean_layout_class_aware_f1": 0.6, "document_count": 5}, + ] + runs = [ + ("a", _summary("a", layout_f1=0.5, leaderboard=leaderboard_a)), + ("b", _summary("b", layout_f1=0.7, leaderboard=leaderboard_b)), + ] + comparison = combine_benchmark_summaries(runs) + + matrix = comparison["parser_matrix"] + parsers = sorted(row["parser"] for row in matrix) + self.assertEqual(parsers, ["docling", "marker", "pymupdf"]) + + by_parser = {row["parser"]: row for row in matrix} + # Docling appears in both runs. + self.assertEqual(by_parser["docling"]["a__mean_layout_class_aware_f1"], 0.9) + self.assertEqual(by_parser["docling"]["b__mean_layout_class_aware_f1"], 0.7) + # Marker missing in run A -> None, present in B. + self.assertIsNone(by_parser["marker"]["a__mean_layout_class_aware_f1"]) + self.assertEqual(by_parser["marker"]["b__mean_layout_class_aware_f1"], 0.6) + # PyMuPDF missing in run B -> None. + self.assertIsNone(by_parser["pymupdf"]["b__mean_layout_class_aware_f1"]) + + def test_duplicate_labels_raise(self): + with self.assertRaises(ValueError): + combine_benchmark_summaries( + [ + ("same", _summary("a", layout_f1=0.5)), + ("same", _summary("b", layout_f1=0.7)), + ] + ) + + def test_summary_loaded_from_path(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + (tmp / "results.json").write_text(json.dumps(_summary("from_path", layout_f1=0.42))) + + comparison = combine_benchmark_summaries([("a", tmp)]) + self.assertEqual(comparison["dataset_summary"][0]["mean_layout_f1"], 0.42) + + def test_missing_metric_yields_none_not_zero(self): + # A summary missing mean_formula_cer (older code, e.g.) preserves None. + sparse_summary = {"dataset_name": "old_run", "document_count": 1} + comparison = combine_benchmark_summaries([("old", sparse_summary)]) + row = comparison["dataset_summary"][0] + self.assertEqual(row["document_count"], 1) + self.assertIsNone(row["mean_layout_f1"]) + self.assertIsNone(row["mean_formula_cer"]) + + +class TestWriteCrossDatasetOutputs(unittest.TestCase): + def test_writes_json_and_csvs(self): + leaderboard = [{"parser": "docling", "mean_layout_class_aware_f1": 0.9, "document_count": 3}] + comparison = combine_benchmark_summaries( + [("a", _summary("a", layout_f1=0.5, leaderboard=leaderboard))] + ) + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + write_cross_dataset_outputs(comparison, tmp) + + self.assertTrue((tmp / "cross_dataset_comparison.json").exists()) + self.assertTrue((tmp / "dataset_summary.csv").exists()) + self.assertTrue((tmp / "parser_matrix.csv").exists()) + + ds_csv = (tmp / "dataset_summary.csv").read_text() + self.assertIn("mean_layout_f1", ds_csv.splitlines()[0]) + self.assertIn("a", ds_csv.splitlines()[1]) + + matrix_csv = (tmp / "parser_matrix.csv").read_text() + self.assertIn("a__mean_layout_class_aware_f1", matrix_csv.splitlines()[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..3c941e219d62400d740044b412768d733fc73f62 --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,152 @@ +"""Dataset loader tests.""" + +from __future__ import annotations + +import json +import tempfile +import unittest +from pathlib import Path + +from zsgdp.benchmarks.datasets import ( + DatasetDocument, + get_dataset_loader, + iter_dataset, + list_dataset_loaders, + register_dataset_loader, +) + + +class TestDatasetRegistry(unittest.TestCase): + def test_built_in_loaders_registered(self): + loaders = list_dataset_loaders() + self.assertIn("custom_folder", loaders) + self.assertIn("omnidocbench", loaders) + self.assertIn("doclaynet", loaders) + + def test_custom_alias_resolves_to_custom_folder(self): + loader_default = get_dataset_loader("default") + loader_alias = get_dataset_loader("custom") + loader_canonical = get_dataset_loader("custom_folder") + self.assertIs(loader_default, loader_canonical) + self.assertIs(loader_alias, loader_canonical) + + def test_unknown_loader_raises(self): + with self.assertRaises(KeyError): + get_dataset_loader("not_a_real_dataset") + + +class TestCustomFolderLoader(unittest.TestCase): + def test_yields_files_with_no_ground_truth(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + (root / "a.md").write_text("# A\n") + (root / "b.md").write_text("# B\n") + (root / "subdir").mkdir() + (root / "subdir" / "ignored.md").write_text("# nope\n") + + documents = list(iter_dataset("custom_folder", root)) + + ids = sorted(document.doc_id for document in documents) + self.assertEqual(ids, ["a", "b"]) + for document in documents: + self.assertIsNone(document.ground_truth) + self.assertEqual(document.dataset_id, "custom_folder") + self.assertTrue(document.path.exists()) + + def test_missing_root_raises(self): + with self.assertRaises(FileNotFoundError): + list(iter_dataset("custom_folder", "/tmp/this-path-should-not-exist-zsgdp")) + + +class TestOmniDocBenchLoader(unittest.TestCase): + def test_pairs_pdf_with_sibling_json(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + (root / "doc1.pdf").write_bytes(b"%PDF-1.4\n%%EOF\n") + (root / "doc1.json").write_text(json.dumps({"reading_order": ["e1", "e2"]})) + (root / "doc2.pdf").write_bytes(b"%PDF-1.4\n%%EOF\n") # no GT + + documents = list(iter_dataset("omnidocbench", root)) + + by_id = {document.doc_id: document for document in documents} + self.assertEqual(set(by_id), {"doc1", "doc2"}) + + self.assertIsNotNone(by_id["doc1"].ground_truth) + self.assertEqual(by_id["doc1"].ground_truth["reading_order"], ["e1", "e2"]) + self.assertTrue(by_id["doc1"].metadata["has_ground_truth"]) + + self.assertIsNone(by_id["doc2"].ground_truth) + self.assertFalse(by_id["doc2"].metadata["has_ground_truth"]) + + def test_no_pdfs_raises(self): + with tempfile.TemporaryDirectory() as tmp: + with self.assertRaises(FileNotFoundError): + list(iter_dataset("omnidocbench", tmp)) + + +class TestDocLayNetLoader(unittest.TestCase): + def test_yields_one_document_per_image_with_filtered_annotations(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + (root / "page1.png").write_bytes(b"\x89PNG\r\n\x1a\n") + (root / "page2.png").write_bytes(b"\x89PNG\r\n\x1a\n") + (root / "annotations.json").write_text( + json.dumps( + { + "images": [ + {"id": 1, "file_name": "page1.png", "width": 800, "height": 1100}, + {"id": 2, "file_name": "page2.png", "width": 800, "height": 1100}, + ], + "annotations": [ + {"id": 10, "image_id": 1, "category_id": 1, "bbox": [0, 0, 100, 50]}, + {"id": 11, "image_id": 1, "category_id": 2, "bbox": [0, 60, 100, 50]}, + {"id": 12, "image_id": 2, "category_id": 1, "bbox": [0, 0, 100, 50]}, + ], + "categories": [ + {"id": 1, "name": "Title"}, + {"id": 2, "name": "Text"}, + ], + } + ) + ) + + documents = list(iter_dataset("doclaynet", root)) + + by_id = {document.doc_id: document for document in documents} + self.assertEqual(set(by_id), {"page1.png", "page2.png"}) + + self.assertEqual(len(by_id["page1.png"].ground_truth["annotations"]), 2) + self.assertEqual(len(by_id["page2.png"].ground_truth["annotations"]), 1) + self.assertEqual(by_id["page1.png"].ground_truth["categories"][1]["name"], "Title") + + def test_missing_annotations_raises(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + (root / "page1.png").write_bytes(b"\x89PNG\r\n\x1a\n") + with self.assertRaises(FileNotFoundError): + list(iter_dataset("doclaynet", root)) + + +class TestRegisterDatasetLoader(unittest.TestCase): + def test_register_and_use_custom_loader(self): + marker = [] + + def fake_loader(root: Path): + marker.append(root) + yield DatasetDocument(dataset_id="fake", doc_id="x", path=root) + + register_dataset_loader("zsgdp_test_fake", fake_loader) + try: + documents = list(iter_dataset("zsgdp_test_fake", Path("/tmp/whatever"))) + finally: + from zsgdp.benchmarks.datasets import _LOADERS + + _LOADERS.pop("zsgdp_test_fake", None) + + self.assertEqual(len(documents), 1) + self.assertEqual(documents[0].dataset_id, "fake") + self.assertEqual(marker, [Path("/tmp/whatever")]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_deployment.py b/tests/test_deployment.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f465c8dcf496e115d0b65bca39765977e4f04d --- /dev/null +++ b/tests/test_deployment.py @@ -0,0 +1,43 @@ +import json +import tempfile +import unittest +from pathlib import Path + +from zsgdp.cli import main +from zsgdp.deployment import check_huggingface_space + + +class DeploymentReadinessTests(unittest.TestCase): + def test_space_check_accepts_current_project(self): + report = check_huggingface_space(Path.cwd()) + + self.assertTrue(report["valid"]) + self.assertEqual(report["target"], "huggingface_spaces") + self.assertEqual(report["space_name"], "zeroshotGPU") + self.assertEqual(report["gpu_models_target"], "zeroshotGPU") + self.assertEqual(report["failure_count"], 0) + self.assertTrue(any(check["status"] == "warn" for check in report["checks"])) + + def test_space_check_cli_writes_report(self): + with tempfile.TemporaryDirectory() as tmp: + output_path = Path(tmp) / "space_report.json" + + code = main(["space-check", "--root", str(Path.cwd()), "--output", str(output_path)]) + + self.assertEqual(code, 0) + self.assertTrue(output_path.exists()) + self.assertTrue(json.loads(output_path.read_text(encoding="utf-8"))["valid"]) + + def test_space_check_reports_missing_files(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + + report = check_huggingface_space(root) + + self.assertFalse(report["valid"]) + self.assertGreater(report["failure_count"], 0) + self.assertTrue(any(check["id"] == "required_file" and check["status"] == "fail" for check in report["checks"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_docling_parser.py b/tests/test_docling_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e9cda4d20adc815cbb4abe2abed8ef4c03c71f --- /dev/null +++ b/tests/test_docling_parser.py @@ -0,0 +1,39 @@ +import unittest + +from zsgdp.parsers.docling_parser import _export_markdown, normalize_docling_markdown +from zsgdp.schema import DocumentProfile, PageProfile + + +class FakeDoclingDocument: + def export_to_markdown(self): + return "# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 |" + + +class DoclingParserTests(unittest.TestCase): + def test_export_markdown_uses_docling_method(self): + self.assertEqual(_export_markdown(FakeDoclingDocument()), "# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 |") + + def test_normalize_docling_markdown_emits_schema(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=20)], + ) + + candidate = normalize_docling_markdown( + markdown="# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 |", + profile=profile, + source_path="sample.pdf", + ) + + self.assertEqual(candidate.parser_name, "docling") + self.assertEqual(len(candidate.elements), 2) + self.assertEqual(len(candidate.tables), 1) + self.assertEqual(candidate.pages[0]["source_parser"], "docling") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_embedding_retriever.py b/tests/test_embedding_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..19bae7a1cc8228c2f6e41b87fa9d734f50dbede6 --- /dev/null +++ b/tests/test_embedding_retriever.py @@ -0,0 +1,190 @@ +"""Tests for the embedding-based retriever and the build_retriever factory.""" + +from __future__ import annotations + +import math +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from zsgdp.benchmarks.embedding_retriever import ( + EmbeddingRetriever, + build_retriever, +) +from zsgdp.benchmarks.parser_quality import run_parser_benchmark +from zsgdp.benchmarks.retrieval import LexicalRetriever, run_retrieval_for_document +from zsgdp.schema import Chunk, ParsedDocument, QualityReport + + +def _chunk(chunk_id: str, text: str) -> Chunk: + return Chunk( + chunk_id=chunk_id, + doc_id="d1", + page_start=1, + page_end=1, + section_path=[], + content_type="prose", + text=text, + token_count=len(text.split()), + ) + + +def _hashing_embedder(dim: int = 32): + """Deterministic toy embedder: tokens hashed into a fixed-dim vector. + + Uses a process-stable hash (hashlib.md5) instead of builtins.hash(), which + is randomized per Python process and would make ranking non-deterministic + across test runs. + """ + + import hashlib + + def stable_hash(token: str) -> int: + return int.from_bytes(hashlib.md5(token.encode("utf-8")).digest()[:8], "big") + + def encode(texts): + out = [] + for text in texts: + vector = [0.0] * dim + for token in text.lower().split(): + vector[stable_hash(token) % dim] += 1.0 + out.append(vector) + return out + + return encode + + +class TestEmbeddingRetriever(unittest.TestCase): + def test_finds_distinctive_chunk_with_injected_embedder(self): + chunks = [ + _chunk("c1", "Apples grow on trees in the orchard."), + _chunk("c2", "Cars drive on highways across the country."), + _chunk("c3", "Boats sail on rivers and oceans."), + ] + retriever = EmbeddingRetriever(embedder=_hashing_embedder()) + retriever.index(chunks) + + ranking = retriever.query("apples orchard", top_k=3) + self.assertEqual(ranking[0], "c1") + + def test_empty_index_returns_empty(self): + retriever = EmbeddingRetriever(embedder=_hashing_embedder()) + self.assertEqual(retriever.query("anything", top_k=3), []) + + def test_zero_norm_vector_skipped(self): + retriever = EmbeddingRetriever(embedder=lambda texts: [[0.0, 0.0, 0.0]] * len(texts)) + retriever.index([_chunk("c1", "anything")]) + # Query embedder also returns zero vector, normalization fails -> empty. + self.assertEqual(retriever.query("anything", top_k=3), []) + + def test_embedder_returning_wrong_count_raises(self): + bad = lambda texts: [[1.0]] # always returns one vector + retriever = EmbeddingRetriever(embedder=bad) + with self.assertRaises(RuntimeError): + retriever.index([_chunk("c1", "a"), _chunk("c2", "b")]) + + def test_lazy_load_path_raises_if_sentence_transformers_missing(self): + retriever = EmbeddingRetriever(model_id="fake/model") + # Force the import to fail by patching builtins.__import__. + import builtins + + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "sentence_transformers": + raise ImportError("not installed") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=fake_import): + with self.assertRaises(RuntimeError) as ctx: + retriever.index([_chunk("c1", "anything")]) + self.assertIn("sentence-transformers", str(ctx.exception)) + + +class TestBuildRetriever(unittest.TestCase): + def test_default_returns_lexical(self): + retriever = build_retriever({}) + self.assertIsInstance(retriever, LexicalRetriever) + + def test_explicit_lexical_backend(self): + retriever = build_retriever({"benchmarks": {"retriever": {"backend": "lexical"}}}) + self.assertIsInstance(retriever, LexicalRetriever) + + def test_embedding_backend_uses_gpu_models_embedding_default(self): + config = { + "benchmarks": {"retriever": {"backend": "embedding"}}, + "gpu": {"models": {"embedding": {"model_id": "custom/model", "task": "retrieval.query"}}}, + } + retriever = build_retriever(config) + self.assertIsInstance(retriever, EmbeddingRetriever) + self.assertEqual(retriever._model_id, "custom/model") + self.assertEqual(retriever._task, "retrieval.query") + + def test_explicit_model_id_overrides_gpu_default(self): + config = { + "benchmarks": {"retriever": {"backend": "embedding", "model_id": "explicit/model"}}, + "gpu": {"models": {"embedding": {"model_id": "ignored/model"}}}, + } + retriever = build_retriever(config) + self.assertEqual(retriever._model_id, "explicit/model") + + def test_unknown_backend_raises(self): + with self.assertRaises(ValueError): + build_retriever({"benchmarks": {"retriever": {"backend": "magic"}}}) + + +class TestRunRetrievalWithEmbedding(unittest.TestCase): + def test_run_retrieval_for_document_accepts_embedding_retriever(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + chunks=[ + _chunk("c1", "Apples grow on trees in the orchard during autumn."), + _chunk("c2", "Submarines navigate beneath the ocean using sonar."), + ], + quality_report=QualityReport(), + ) + retriever = EmbeddingRetriever(embedder=_hashing_embedder()) + run = run_retrieval_for_document(parsed, retriever=retriever) + self.assertTrue(run["evaluated"]) + self.assertGreater(run["query_count"], 0) + for result in run["results"]: + truth = result["truths"][0] + self.assertEqual(result["retrieved"][0], truth) + + +class TestBenchmarkOptInToEmbeddingBackend(unittest.TestCase): + def test_benchmark_uses_embedding_when_config_says_so(self): + # Patch build_retriever to return an EmbeddingRetriever with our toy embedder + # so the benchmark exercises the opt-in code path without loading a real model. + toy = EmbeddingRetriever(embedder=_hashing_embedder()) + + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text( + "# Doc\n\n" + "Apples grow on trees in the orchard during autumn season.\n\n" + "Submarines navigate beneath the ocean using sonar pulses across waters.\n", + encoding="utf-8", + ) + + with patch("zsgdp.benchmarks.parser_quality.load_config") as load_config: + load_config.return_value = { + "benchmarks": {"retriever": {"backend": "embedding"}}, + } + with patch( + "zsgdp.benchmarks.embedding_retriever.build_retriever", + return_value=toy, + ) as build_call: + summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder") + + self.assertGreaterEqual(build_call.call_count, 1) + self.assertTrue(summary["documents"][0]["retrieval_evaluated"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_env_loading.py b/tests/test_env_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..2cae7bab681f0252eebdddc375c24b882f2d1f47 --- /dev/null +++ b/tests/test_env_loading.py @@ -0,0 +1,110 @@ +"""Tests for .env loading and HF_TOKEN resolution.""" + +from __future__ import annotations + +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from zsgdp.config import hf_token, load_env_file + + +class LoadEnvFileTests(unittest.TestCase): + def test_loads_simple_key_value(self): + with tempfile.TemporaryDirectory() as tmp: + env = Path(tmp) / ".env" + env.write_text("HF_TOKEN=hf_test_value_123\nOTHER=foo\n", encoding="utf-8") + + with patch.dict("os.environ", {}, clear=False): + os.environ.pop("HF_TOKEN", None) + os.environ.pop("OTHER", None) + applied = load_env_file(env) + + self.assertEqual(applied["HF_TOKEN"], "hf_test_value_123") + self.assertEqual(applied["OTHER"], "foo") + + def test_skips_comments_and_blank_lines(self): + with tempfile.TemporaryDirectory() as tmp: + env = Path(tmp) / ".env" + env.write_text( + "# top comment\n\nFOO=bar\n # indented\n\nBAZ=qux\n", + encoding="utf-8", + ) + with patch.dict("os.environ", {}, clear=False): + os.environ.pop("FOO", None) + os.environ.pop("BAZ", None) + applied = load_env_file(env) + + self.assertEqual(set(applied), {"FOO", "BAZ"}) + + def test_quoted_values_unquoted(self): + with tempfile.TemporaryDirectory() as tmp: + env = Path(tmp) / ".env" + env.write_text('A="quoted value"\nB=\'single\'\nC=plain\n', encoding="utf-8") + with patch.dict("os.environ", {}, clear=False): + for key in ("A", "B", "C"): + os.environ.pop(key, None) + applied = load_env_file(env) + + self.assertEqual(applied["A"], "quoted value") + self.assertEqual(applied["B"], "single") + self.assertEqual(applied["C"], "plain") + + def test_export_prefix_stripped(self): + with tempfile.TemporaryDirectory() as tmp: + env = Path(tmp) / ".env" + env.write_text("export FOO=bar\n", encoding="utf-8") + with patch.dict("os.environ", {}, clear=False): + os.environ.pop("FOO", None) + applied = load_env_file(env) + + self.assertEqual(applied["FOO"], "bar") + + def test_existing_env_wins_unless_override(self): + with tempfile.TemporaryDirectory() as tmp: + env = Path(tmp) / ".env" + env.write_text("FOO=from_file\n", encoding="utf-8") + + with patch.dict("os.environ", {"FOO": "from_env"}, clear=False): + applied = load_env_file(env) + # Default behaviour: don't override. + self.assertNotIn("FOO", applied) + self.assertEqual(os.environ["FOO"], "from_env") + + # With override=True, file wins. + applied = load_env_file(env, override=True) + self.assertEqual(applied["FOO"], "from_file") + self.assertEqual(os.environ["FOO"], "from_file") + + def test_missing_file_returns_empty_no_error(self): + self.assertEqual(load_env_file(Path("/tmp/zsgdp_does_not_exist.env")), {}) + + +class HFTokenResolverTests(unittest.TestCase): + def test_picks_up_hf_token(self): + with patch.dict( + "os.environ", + {"HF_TOKEN": "primary", "HUGGING_FACE_HUB_TOKEN": "secondary"}, + clear=False, + ): + self.assertEqual(hf_token(), "primary") + + def test_falls_through_alternative_names(self): + with patch.dict("os.environ", {}, clear=True): + os.environ["HUGGINGFACE_TOKEN"] = "fallback" + self.assertEqual(hf_token(), "fallback") + + def test_recognises_hf_access_token_alias(self): + with patch.dict("os.environ", {}, clear=True): + os.environ["HF_ACCESS_TOKEN"] = "from_alias" + self.assertEqual(hf_token(), "from_alias") + + def test_returns_none_when_unset(self): + with patch.dict("os.environ", {}, clear=True): + self.assertIsNone(hf_token()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_external_parser_adapters.py b/tests/test_external_parser_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..5aee8bdf75f528bba60da334b2191bdad7c7f2b5 --- /dev/null +++ b/tests/test_external_parser_adapters.py @@ -0,0 +1,69 @@ +import unittest +from unittest.mock import patch + +from zsgdp.config import load_config +from zsgdp.normalize.normalize_unstructured import normalize_unstructured_parts +from zsgdp.parsers.external import MinerUParser, OlmOCRParser, PaddleOCRParser +from zsgdp.schema import DocumentProfile, PageProfile + + +class ExternalParserAdapterTests(unittest.TestCase): + def test_command_backed_parsers_normalize_markdown(self): + cases = [ + (MinerUParser, "mineru"), + (OlmOCRParser, "olmocr"), + (PaddleOCRParser, "paddleocr"), + ] + profile = _profile() + + for parser_class, parser_name in cases: + with self.subTest(parser=parser_name), patch.object(parser_class, "available", return_value=True), patch( + "zsgdp.parsers.external.run_external_parser_to_markdown", + return_value="# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 |", + ): + candidate = parser_class().parse("sample.pdf", profile, load_config()) + + self.assertEqual(candidate.parser_name, parser_name) + self.assertEqual(candidate.elements[0].source_parser, parser_name) + self.assertEqual(len(candidate.tables), 1) + self.assertEqual(candidate.provenance["requested_pages"], [1]) + + def test_unstructured_normalizer_preserves_page_and_title_metadata(self): + class Metadata: + page_number = 2 + + class Title: + category = "Title" + metadata = Metadata() + + def __str__(self): + return "Executive Summary" + + class Narrative: + category = "NarrativeText" + metadata = Metadata() + + def __str__(self): + return "The document parser keeps provenance." + + candidate = normalize_unstructured_parts(parts=[Title(), Narrative()], profile=_profile(), source_path="sample.pdf") + + self.assertEqual(candidate.parser_name, "unstructured") + self.assertEqual(candidate.elements[0].page_num, 2) + self.assertEqual(candidate.elements[0].type, "title") + self.assertEqual(candidate.elements[0].markdown, "# Executive Summary") + + +def _profile(): + return DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=20)], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_gpu_runner.py b/tests/test_gpu_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4b896303e441b753382bc87aff51ac937b9404 --- /dev/null +++ b/tests/test_gpu_runner.py @@ -0,0 +1,185 @@ +import json +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from zsgdp.cli import main +from zsgdp.config import load_config +from zsgdp.gpu.batching import batch_gpu_tasks +from zsgdp.gpu.runner import dry_run_gpu_tasks, load_gpu_tasks, run_gpu_task_manifest +from zsgdp.gpu.worker import GPUWorker +from zsgdp.utils import write_jsonl + + +class GPURunnerTests(unittest.TestCase): + def test_batch_gpu_tasks_groups_by_task_type_and_batch_size(self): + tasks = [ + {"task_id": "a", "task_type": "figure_description", "priority": 1}, + {"task_id": "b", "task_type": "figure_description", "priority": 2}, + {"task_id": "c", "task_type": "table_vlm_repair", "priority": 3}, + ] + + batches = batch_gpu_tasks(tasks, max_batch_size=1) + + self.assertEqual(len(batches), 3) + self.assertEqual(batches[0]["task_count"], 1) + self.assertEqual({batch["task_type"] for batch in batches}, {"figure_description", "table_vlm_repair"}) + + def test_worker_reports_missing_image_path(self): + worker = GPUWorker(load_config()) + + result = worker.run( + { + "task_id": "gt1", + "task_type": "figure_description", + "doc_id": "d1", + "page_nums": [1], + "image_path": "/tmp/does-not-exist.png", + } + ) + + self.assertEqual(result["status"], "blocked_missing_inputs") + self.assertIn("image_path", result["readiness"]["missing_inputs"]) + + def test_run_gpu_task_manifest_writes_report(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + image_path = tmp_path / "figure.png" + image_path.write_bytes(b"fake") + tasks_path = tmp_path / "gpu_tasks.jsonl" + report_path = tmp_path / "report.json" + write_jsonl( + tasks_path, + [ + { + "task_id": "gt1", + "task_type": "figure_description", + "doc_id": "d1", + "page_nums": [1], + "image_path": str(image_path), + "priority": 60, + } + ], + ) + + report = run_gpu_task_manifest(tmp_path, config=load_config(), output_path=report_path) + + self.assertEqual(report["task_count"], 1) + self.assertEqual(report["ready_count"], 1) + self.assertTrue(report_path.exists()) + self.assertEqual(json.loads(report_path.read_text(encoding="utf-8"))["batch_count"], 1) + + def test_dry_run_gpu_tasks_accepts_in_memory_tasks(self): + with tempfile.TemporaryDirectory() as tmp: + image_path = Path(tmp) / "figure.png" + image_path.write_bytes(b"fake") + + report = dry_run_gpu_tasks( + [ + { + "task_id": "gt1", + "task_type": "figure_description", + "doc_id": "d1", + "page_nums": [1], + "image_path": str(image_path), + "priority": 60, + } + ], + config=load_config(), + ) + + self.assertEqual(report["ready_count"], 1) + self.assertEqual(report["blocked_count"], 0) + + def test_execute_gpu_tasks_dispatches_transformers_client(self): + with tempfile.TemporaryDirectory() as tmp: + image_path = Path(tmp) / "figure.png" + image_path.write_bytes(b"fake") + task = { + "task_id": "gt1", + "task_type": "figure_description", + "doc_id": "d1", + "page_nums": [1], + "image_path": str(image_path), + "priority": 60, + "backend": "transformers", + "model_role": "vlm", + "model_id": "local-test-model", + } + + with patch("zsgdp.gpu.worker.TransformersClient") as client_class: + client_class.return_value.execute_task.return_value = {"status": "executed", "text": "Figure description."} + report = dry_run_gpu_tasks([task], config=load_config(), dry_run=False) + + self.assertFalse(report["dry_run"]) + self.assertEqual(report["executed_count"], 1) + self.assertEqual(report["failed_count"], 0) + self.assertEqual(report["batches"][0]["status"], "execute_complete") + client_class.return_value.execute_task.assert_called_once() + + def test_load_gpu_tasks_accepts_file_path(self): + with tempfile.TemporaryDirectory() as tmp: + tasks_path = Path(tmp) / "tasks.jsonl" + write_jsonl(tasks_path, [{"task_id": "gt1"}]) + + self.assertEqual(load_gpu_tasks(tasks_path)[0]["task_id"], "gt1") + + def test_run_gpu_tasks_cli(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + tasks_path = tmp_path / "gpu_tasks.jsonl" + report_path = tmp_path / "report.json" + write_jsonl( + tasks_path, + [ + { + "task_id": "gt1", + "task_type": "figure_description", + "doc_id": "d1", + "page_nums": [1], + "image_path": str(tmp_path / "missing.png"), + "priority": 60, + } + ], + ) + + code = main(["run-gpu-tasks", "--input", str(tasks_path), "--output", str(report_path)]) + + self.assertEqual(code, 0) + self.assertTrue(report_path.exists()) + + def test_run_gpu_tasks_cli_execute(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + image_path = tmp_path / "figure.png" + image_path.write_bytes(b"fake") + tasks_path = tmp_path / "gpu_tasks.jsonl" + report_path = tmp_path / "report.json" + write_jsonl( + tasks_path, + [ + { + "task_id": "gt1", + "task_type": "figure_description", + "doc_id": "d1", + "page_nums": [1], + "image_path": str(image_path), + "priority": 60, + "backend": "transformers", + "model_role": "vlm", + "model_id": "local-test-model", + } + ], + ) + + with patch("zsgdp.gpu.worker.TransformersClient") as client_class: + client_class.return_value.execute_task.return_value = {"status": "executed", "text": "done"} + code = main(["run-gpu-tasks", "--input", str(tasks_path), "--output", str(report_path), "--execute"]) + + self.assertEqual(code, 0) + self.assertEqual(json.loads(report_path.read_text(encoding="utf-8"))["executed_count"], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_gpu_runtime.py b/tests/test_gpu_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..f33a22e9257d4e887b2929c290c0ea702d5759fa --- /dev/null +++ b/tests/test_gpu_runtime.py @@ -0,0 +1,47 @@ +import unittest +from unittest.mock import patch + +from zsgdp.config import load_config +from zsgdp.gpu import GPUModelConfig, collect_gpu_runtime_status + + +class GPURuntimeTests(unittest.TestCase): + def test_model_config_reads_gpu_section(self): + config = load_config(overrides={"gpu": {"backend": "vllm", "provider": "huggingface_spaces", "space_name": "zeroshotGPU", "max_batch_size": 8}}) + + model_config = GPUModelConfig.from_config(config) + + self.assertEqual(model_config.backend, "vllm") + self.assertEqual(model_config.provider, "huggingface_spaces") + self.assertEqual(model_config.space_name, "zeroshotGPU") + self.assertEqual(model_config.max_batch_size, 8) + + def test_collect_runtime_detects_space_environment(self): + config = load_config() + + with patch.dict("os.environ", {"SPACE_ID": "user/zeroshotGPU", "SPACE_HARDWARE": "l4x1"}, clear=False): + status = collect_gpu_runtime_status(config).to_dict() + + self.assertEqual(status["provider"], "huggingface_spaces") + self.assertEqual(status["space_name"], "zeroshotGPU") + self.assertEqual(status["gpu_models_target"], "zeroshotGPU") + self.assertTrue(status["running_on_huggingface_space"]) + self.assertEqual(status["space_id"], "user/zeroshotGPU") + self.assertEqual(status["hardware"], "l4x1") + self.assertIn(status["device"], {"cpu", "cuda", "mps"}) + self.assertIn("torch_available", status) + self.assertEqual(status["configured_models"]["vlm"]["model_id"], "Qwen/Qwen2.5-VL-3B-Instruct") + self.assertEqual(status["configured_models"]["embedding"]["model_id"], "jinaai/jina-embeddings-v3") + + def test_collect_runtime_reports_local_note(self): + config = load_config() + + with patch.dict("os.environ", {"SPACE_ID": "", "SPACE_HOST": "", "SPACE_HARDWARE": ""}, clear=False): + status = collect_gpu_runtime_status(config) + + self.assertFalse(status.running_on_huggingface_space) + self.assertTrue(any("local run" in note for note in status.notes)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_gpu_tasks.py b/tests/test_gpu_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..422abc07d53208dd8aa50ec52135becb90f09526 --- /dev/null +++ b/tests/test_gpu_tasks.py @@ -0,0 +1,99 @@ +import unittest + +from zsgdp.config import load_config +from zsgdp.gpu import plan_gpu_tasks +from zsgdp.routing import RouteDecision +from zsgdp.routing.budgets import Budget +from zsgdp.schema import DocumentProfile, FigureObject, PageProfile, ParsedDocument, TableObject + + +class GPUTaskTests(unittest.TestCase): + def test_plan_gpu_tasks_includes_route_ocr_table_and_figure(self): + config = load_config(overrides={"chunking": {"vision_guided": True}}) + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[ + PageProfile(page_num=1, scanned_score=0.8, digital_text_chars=0, digital_text_quality=0.0), + ], + ) + parsed = ParsedDocument( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + pages=[ + { + "page_num": 1, + "parser_pages": [ + {"rendered_page": {"image_path": "/tmp/page.png"}}, + ], + } + ], + ) + parsed.tables.append( + TableObject( + table_id="t1", + page_nums=[1], + bbox=[(1.0, 2.0, 3.0, 4.0)], + markdown="| A | B |\n| --- | --- |\n| 1 | 2 |", + provenance={"crop_path": "/tmp/table.png"}, + ) + ) + parsed.figures.append(FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png")) + routes = [ + RouteDecision( + page_id=1, + experts=["pymupdf", "vlm_figure_repair"], + reason="figure-heavy page", + budget=Budget(), + labels=["figure_heavy"], + ) + ] + + tasks = plan_gpu_tasks(profile, parsed, config, routes) + + task_types = [task["task_type"] for task in tasks] + self.assertIn("vlm_route_repair", task_types) + self.assertIn("ocr_page", task_types) + self.assertIn("table_vlm_repair", task_types) + self.assertIn("figure_description", task_types) + self.assertEqual(tasks[0]["task_type"], "vlm_route_repair") + self.assertTrue(all(task["provider"] == "huggingface_spaces" for task in tasks)) + self.assertTrue(all(task["space_name"] == "zeroshotGPU" for task in tasks)) + self.assertTrue(all(task["model_id"] for task in tasks)) + self.assertEqual(_task_by_type(tasks, "ocr_page")["model_role"], "ocr") + self.assertEqual(_task_by_type(tasks, "table_vlm_repair")["model_role"], "table") + self.assertEqual(_task_by_type(tasks, "figure_description")["model_role"], "vlm") + self.assertEqual(_task_by_type(tasks, "figure_description")["model_id"], "Qwen/Qwen2.5-VL-3B-Instruct") + + def test_plan_gpu_tasks_respects_max_vlm_calls(self): + config = load_config(overrides={"gpu": {"max_vlm_calls_per_doc": 1}, "chunking": {"vision_guided": True}}) + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, scanned_score=0.8)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf") + parsed.figures.append(FigureObject(figure_id="f1", page_num=1, image_path="/tmp/figure.png")) + + tasks = plan_gpu_tasks(profile, parsed, config) + + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0]["task_type"], "ocr_page") + + +def _task_by_type(tasks, task_type): + for task in tasks: + if task["task_type"] == task_type: + return task + raise AssertionError(f"Missing task type: {task_type}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_layout_f1.py b/tests/test_layout_f1.py new file mode 100644 index 0000000000000000000000000000000000000000..42d46e91de8a962cd351ecd73ccf50dd340eb9b3 --- /dev/null +++ b/tests/test_layout_f1.py @@ -0,0 +1,190 @@ +"""Tests for layout F1 metric and ground-truth adapters.""" + +from __future__ import annotations + +import unittest + +from zsgdp.benchmarks.ground_truth import ( + canonical_category, + doclaynet_layout_truths, + omnidocbench_layout_truths, + parsed_layout_predictions, +) +from zsgdp.schema import Element, FigureObject, ParsedDocument, QualityReport, TableObject +from zsgdp.verify.layout_f1 import compute_layout_f1 + + +def _item(bbox, category="paragraph", page_num=1): + return {"bbox": bbox, "category": category, "page_num": page_num} + + +class TestComputeLayoutF1(unittest.TestCase): + def test_perfect_match_yields_f1_1(self): + predictions = [_item((0, 0, 100, 50)), _item((0, 60, 100, 110), "table")] + truths = [_item((0, 0, 100, 50)), _item((0, 60, 100, 110), "table")] + result = compute_layout_f1(predictions, truths) + self.assertEqual(result["class_aware"]["f1"], 1.0) + self.assertEqual(result["class_agnostic"]["f1"], 1.0) + self.assertEqual(result["class_aware"]["tp"], 2) + + def test_zero_match_yields_f1_0(self): + predictions = [_item((0, 0, 50, 50))] + truths = [_item((1000, 1000, 1100, 1100))] + result = compute_layout_f1(predictions, truths) + self.assertEqual(result["class_aware"]["f1"], 0.0) + self.assertEqual(result["class_aware"]["fp"], 1) + self.assertEqual(result["class_aware"]["fn"], 1) + + def test_iou_below_threshold_misses(self): + # 50% overlap by area in one axis only -> IoU < 0.5 + predictions = [_item((0, 0, 100, 100))] + truths = [_item((60, 0, 160, 100))] + result = compute_layout_f1(predictions, truths, iou_threshold=0.5) + self.assertEqual(result["class_aware"]["tp"], 0) + + def test_class_aware_vs_agnostic(self): + # Same bbox, different category -> agnostic matches, aware doesn't. + predictions = [_item((0, 0, 100, 100), "paragraph")] + truths = [_item((0, 0, 100, 100), "title")] + result = compute_layout_f1(predictions, truths) + self.assertEqual(result["class_aware"]["tp"], 0) + self.assertEqual(result["class_agnostic"]["tp"], 1) + + def test_per_category_breakdown(self): + predictions = [_item((0, 0, 100, 100), "title"), _item((0, 200, 100, 300), "table")] + truths = [_item((0, 0, 100, 100), "title")] + result = compute_layout_f1(predictions, truths) + per_category = result["per_category"] + self.assertEqual(per_category["title"]["tp"], 1) + self.assertEqual(per_category["table"]["fp"], 1) + + def test_empty_inputs_are_vacuously_correct(self): + self.assertEqual(compute_layout_f1([], [])["class_aware"]["f1"], 1.0) + + def test_predictions_only_yields_zero(self): + result = compute_layout_f1([_item((0, 0, 10, 10))], []) + self.assertEqual(result["class_aware"]["fp"], 1) + self.assertEqual(result["class_aware"]["f1"], 0.0) + + def test_page_num_must_match(self): + predictions = [_item((0, 0, 100, 100), "table", page_num=1)] + truths = [_item((0, 0, 100, 100), "table", page_num=2)] + result = compute_layout_f1(predictions, truths) + self.assertEqual(result["class_aware"]["tp"], 0) + + +class TestDocLayNetAdapter(unittest.TestCase): + def test_xywh_converted_and_categories_normalized(self): + ground_truth = { + "image": {"id": 5, "file_name": "p.png", "page_no": 5}, + "annotations": [ + {"image_id": 5, "category_id": 1, "bbox": [10, 20, 50, 60]}, + {"image_id": 5, "category_id": 2, "bbox": [100, 0, 40, 30]}, + ], + "categories": {1: {"name": "Title"}, 2: {"name": "Section-header"}}, + } + truths = doclaynet_layout_truths(ground_truth) + self.assertEqual(len(truths), 2) + self.assertEqual(truths[0]["bbox"], (10.0, 20.0, 60.0, 80.0)) + self.assertEqual(truths[0]["category"], "title") + self.assertEqual(truths[0]["page_num"], 5) + self.assertEqual(truths[1]["category"], "heading") + + def test_invalid_annotations_dropped(self): + ground_truth = { + "image": {"id": 1, "file_name": "p.png"}, + "annotations": [ + {"image_id": 1, "category_id": 1, "bbox": [0, 0, 0, 0]}, + {"image_id": 1, "category_id": 1}, + ], + "categories": {1: {"name": "Text"}}, + } + self.assertEqual(doclaynet_layout_truths(ground_truth), []) + + +class TestOmniDocBenchAdapter(unittest.TestCase): + def test_picks_layout_dets_first(self): + ground_truth = { + "layout_dets": [ + {"bbox": [0, 0, 100, 50], "category": "title", "page_num": 1}, + {"bbox": [0, 100, 100, 150], "category": "Table", "page": 1}, + ] + } + truths = omnidocbench_layout_truths(ground_truth) + self.assertEqual(len(truths), 2) + self.assertEqual(truths[0]["category"], "title") + self.assertEqual(truths[1]["category"], "table") + + def test_pages_nested_records(self): + ground_truth = { + "pages": [ + {"page_num": 1, "elements": [{"bbox": [0, 0, 10, 10], "category": "paragraph"}]}, + {"page_num": 2, "elements": [{"bbox": [0, 0, 10, 10], "category": "table"}]}, + ] + } + truths = omnidocbench_layout_truths(ground_truth) + self.assertEqual(len(truths), 2) + self.assertEqual(truths[0]["page_num"], 1) + self.assertEqual(truths[1]["page_num"], 2) + + def test_unknown_shape_returns_empty(self): + self.assertEqual(omnidocbench_layout_truths({"weird": "shape"}), []) + self.assertEqual(omnidocbench_layout_truths(None), []) + + +class TestParsedPredictions(unittest.TestCase): + def test_extracts_bboxes_from_elements_tables_figures(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.pdf", + file_type="pdf", + elements=[ + Element( + element_id="e1", + doc_id="d1", + page_num=1, + type="title", + text="Title", + bbox=(0.0, 0.0, 100.0, 30.0), + ), + Element( + element_id="e2", + doc_id="d1", + page_num=1, + type="paragraph", + text="No bbox", + ), + ], + tables=[ + TableObject( + table_id="t1", + page_nums=[1], + bbox=[(0.0, 100.0, 200.0, 200.0)], + ) + ], + figures=[ + FigureObject( + figure_id="f1", + page_num=2, + bbox=(50.0, 50.0, 150.0, 250.0), + ) + ], + quality_report=QualityReport(), + ) + predictions = parsed_layout_predictions(parsed) + categories = sorted(prediction["category"] for prediction in predictions) + self.assertEqual(categories, ["figure", "table", "title"]) + self.assertEqual(len(predictions), 3) + + +class TestCanonicalCategory(unittest.TestCase): + def test_canonical_mapping(self): + self.assertEqual(canonical_category("Picture"), "figure") + self.assertEqual(canonical_category("Section-header"), "heading") + self.assertEqual(canonical_category("Page-footer"), "footer") + self.assertEqual(canonical_category("formula"), "formula") + self.assertEqual(canonical_category("Mystery"), "mystery") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..17c59d2f2e869031775745232fe02c98c6aad07f --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,125 @@ +"""Tests for the logging configuration and structured log emission.""" + +from __future__ import annotations + +import io +import json +import logging +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from zsgdp.logging_config import configure_logging, get_logger +from zsgdp.pipeline import parse_document + + +class ConfigureLoggingTests(unittest.TestCase): + def setUp(self) -> None: + # Reset between tests so each one configures cleanly. + root = logging.getLogger("zsgdp") + for handler in list(root.handlers): + root.removeHandler(handler) + + def test_idempotent_configuration(self): + stream = io.StringIO() + configure_logging(level="INFO", json_format=False, stream=stream) + configure_logging(level="INFO", json_format=False, stream=stream) + root = logging.getLogger("zsgdp") + # Idempotent: still exactly one handler attached. + self.assertEqual(len(root.handlers), 1) + + def test_text_format_emits_human_readable(self): + stream = io.StringIO() + configure_logging(level="INFO", json_format=False, stream=stream) + get_logger("zsgdp.test").info("hello", extra={"doc_id": "d1"}) + output = stream.getvalue() + self.assertIn("INFO", output) + self.assertIn("zsgdp.test", output) + self.assertIn("hello", output) + + def test_json_format_emits_one_line_records(self): + stream = io.StringIO() + configure_logging(level="INFO", json_format=True, stream=stream) + get_logger("zsgdp.test").info("event", extra={"doc_id": "abc", "count": 3}) + output = stream.getvalue().strip() + record = json.loads(output) + self.assertEqual(record["level"], "INFO") + self.assertEqual(record["logger"], "zsgdp.test") + self.assertEqual(record["message"], "event") + self.assertEqual(record["doc_id"], "abc") + self.assertEqual(record["count"], 3) + + def test_default_level_is_warning(self): + stream = io.StringIO() + with patch.dict("os.environ", {"ZSGDP_LOG_LEVEL": "", "ZSGDP_LOG_JSON": ""}, clear=False): + configure_logging(stream=stream) + get_logger("zsgdp.test").info("hidden_at_default_level") + self.assertNotIn("hidden_at_default_level", stream.getvalue()) + get_logger("zsgdp.test").warning("visible_at_default_level") + self.assertIn("visible_at_default_level", stream.getvalue()) + + def test_get_logger_namespacing(self): + self.assertEqual(get_logger().name, "zsgdp") + self.assertEqual(get_logger("zsgdp.foo").name, "zsgdp.foo") + # Bare names get prefixed. + self.assertEqual(get_logger("foo").name, "zsgdp.foo") + + +class PipelineLogEmissionTests(unittest.TestCase): + def test_parse_emits_start_and_end_records(self): + # Reset handlers so assertLogs works against the named logger. + root = logging.getLogger("zsgdp") + for handler in list(root.handlers): + root.removeHandler(handler) + root.setLevel(logging.DEBUG) + root.propagate = True + + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "doc.md" + input_path.write_text("# Doc\n\nHello.\n", encoding="utf-8") + + with self.assertLogs("zsgdp.pipeline", level="INFO") as captured: + parse_document(input_path, Path(tmp) / "out") + + messages = [record.message for record in captured.records] + self.assertIn("parse_start", messages) + self.assertIn("parse_end", messages) + + # Find a parse_end record and assert structured fields are present. + parse_end = next(record for record in captured.records if record.message == "parse_end") + self.assertTrue(hasattr(parse_end, "doc_id")) + self.assertTrue(hasattr(parse_end, "elapsed_seconds")) + self.assertTrue(hasattr(parse_end, "quality_score")) + self.assertTrue(hasattr(parse_end, "element_count")) + + +class RepairLogEmissionTests(unittest.TestCase): + def test_repair_emits_iteration_record(self): + root = logging.getLogger("zsgdp") + for handler in list(root.handlers): + root.removeHandler(handler) + root.setLevel(logging.DEBUG) + root.propagate = True + + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "report.md" + # Malformed table forces a repair iteration. + input_path.write_text( + "# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 | 3 |\n", + encoding="utf-8", + ) + + with self.assertLogs("zsgdp.repair.controller", level="INFO") as captured: + parse_document(input_path, Path(tmp) / "out") + + repair_records = [r for r in captured.records if r.message == "repair_iteration"] + self.assertGreaterEqual(len(repair_records), 1) + # Each record carries the iteration index. + for record in repair_records: + self.assertTrue(hasattr(record, "iteration")) + self.assertTrue(hasattr(record, "status")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_markdown_normalizer.py b/tests/test_markdown_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3f23b2b4e0e64cc624a8762b521f6a8b71b8540e --- /dev/null +++ b/tests/test_markdown_normalizer.py @@ -0,0 +1,63 @@ +import unittest + +from zsgdp.normalize.markdown import markdown_to_blocks, normalize_markdown_candidate, normalize_markdown_table + + +class MarkdownNormalizerTests(unittest.TestCase): + def test_markdown_to_blocks_preserves_pages_tables_and_images(self): + markdown = """# Report + +Intro paragraph. + +| Region | Q1 | +| --- | --- | +| NA | 10 | + + + +## Figure Section + +![Chart caption](chart.png) +""" + + candidate = normalize_markdown_candidate( + markdown=markdown, + doc_id="d1", + source_path="sample.md", + file_type="markdown", + parser_name="test", + ) + + self.assertEqual([page["page_num"] for page in candidate.pages], [1, 2]) + self.assertEqual(len(candidate.tables), 1) + self.assertEqual(candidate.tables[0].page_nums, [1]) + self.assertEqual(len(candidate.figures), 1) + self.assertEqual(candidate.figures[0].page_num, 2) + self.assertEqual(candidate.figures[0].image_path, "chart.png") + + def test_normalize_markdown_table_repairs_separator(self): + table = "| A | B |\n| --- | --- |\n| 1 | 2 |" + + self.assertEqual(normalize_markdown_table(table), "| A | B |\n| --- | --- |\n| 1 | 2 |") + + def test_normalize_plain_aligned_table(self): + table = "Region Q1 Q2\nNorth America 10 12\nEurope 8 7" + + self.assertEqual( + normalize_markdown_table(table), + "| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North America | 10 | 12 |\n| Europe | 8 | 7 |", + ) + + def test_markdown_to_blocks_detects_plain_aligned_table(self): + blocks = markdown_to_blocks("# Report\n\nRegion Q1 Q2\nNorth America 10 12\nEurope 8 7") + + self.assertEqual(blocks[1].block_type, "table") + + def test_markdown_to_blocks_classifies_caption(self): + blocks = markdown_to_blocks("Figure 1 Revenue trend") + + self.assertEqual(blocks[0].block_type, "caption") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_marker_parser.py b/tests/test_marker_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbca08e6a61db8ff0aeb44ee390dca15787a253 --- /dev/null +++ b/tests/test_marker_parser.py @@ -0,0 +1,73 @@ +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from zsgdp.config import load_config +from zsgdp.parsers.external import MarkerParser, _read_external_markdown, _read_marker_markdown, normalize_marker_markdown +from zsgdp.schema import DocumentProfile, PageProfile + + +class MarkerParserTests(unittest.TestCase): + def test_normalize_marker_markdown_emits_common_schema(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=20)], + ) + + candidate = normalize_marker_markdown( + markdown="# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 |\n\n![Chart](chart.png)", + profile=profile, + source_path="sample.pdf", + ) + + self.assertEqual(candidate.parser_name, "marker") + self.assertEqual(len(candidate.tables), 1) + self.assertEqual(len(candidate.figures), 1) + self.assertEqual(candidate.pages[0]["source_parser"], "marker") + + def test_marker_parser_runs_markdown_through_normalizer(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=20)], + ) + + with patch.object(MarkerParser, "available", return_value=True), patch( + "zsgdp.parsers.external.run_marker_to_markdown", + return_value="# Report\n\nBody.", + ): + candidate = MarkerParser().parse("sample.pdf", profile, load_config()) + + self.assertEqual(candidate.parser_name, "marker") + self.assertEqual(candidate.elements[0].source_parser, "marker") + self.assertEqual(candidate.provenance["requested_pages"], [1]) + + def test_read_marker_markdown_prefers_markdown_file(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + nested = root / "sample" + nested.mkdir() + (nested / "other.md").write_text("# Other", encoding="utf-8") + (nested / "markdown.md").write_text("# Preferred", encoding="utf-8") + + markdown = _read_marker_markdown(root) + + self.assertEqual(markdown, "# Preferred") + + def test_read_external_markdown_falls_back_to_stdout(self): + with tempfile.TemporaryDirectory() as tmp: + markdown = _read_external_markdown(Path(tmp), parser_name="mineru", stdout="# From stdout") + + self.assertEqual(markdown, "# From stdout") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_merge.py b/tests/test_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..376e5efe8037f7fa15d49303c63f46b68ff18ac2 --- /dev/null +++ b/tests/test_merge.py @@ -0,0 +1,134 @@ +import unittest + +from zsgdp.merge.dedupe import dedupe_elements, dedupe_tables +from zsgdp.schema import Element, TableObject + + +class MergeDedupeTests(unittest.TestCase): + def test_merges_docling_heading_with_pymupdf_bbox(self): + elements = [ + Element( + element_id="docling_p1_e1", + doc_id="d1", + page_num=1, + type="heading", + text="## Revenue Summary", + markdown="## Revenue Summary", + reading_order=1, + confidence=0.88, + source_parser="docling", + ), + Element( + element_id="pymupdf_p1_e1", + doc_id="d1", + page_num=1, + type="paragraph", + text="Revenue Summary", + bbox=(72.0, 100.0, 200.0, 124.0), + reading_order=1, + confidence=0.86, + source_parser="pymupdf", + ), + ] + + deduped = dedupe_elements(elements) + + self.assertEqual(len(deduped), 1) + self.assertEqual(deduped[0].source_parser, "docling") + self.assertEqual(deduped[0].bbox, (72.0, 100.0, 200.0, 124.0)) + self.assertEqual(deduped[0].provenance["bbox_source_parser"], "pymupdf") + + def test_drops_paragraph_duplicate_of_structured_table(self): + elements = [ + Element( + element_id="docling_p1_e1", + doc_id="d1", + page_num=1, + type="paragraph", + text="Region Q1 Q2 North America 10 12 Europe 8 7", + reading_order=1, + confidence=0.88, + source_parser="docling", + ), + Element( + element_id="pymupdf_p1_e1", + doc_id="d1", + page_num=1, + type="table", + markdown="| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North America | 10 | 12 |\n| Europe | 8 | 7 |", + reading_order=1, + confidence=0.72, + source_parser="pymupdf", + ), + ] + + deduped = dedupe_elements(elements) + + self.assertEqual(len(deduped), 1) + self.assertEqual(deduped[0].type, "table") + + def test_merges_duplicate_table_elements_and_keeps_better_grid(self): + elements = [ + Element( + element_id="docling_p1_e3", + doc_id="d1", + page_num=1, + type="table", + markdown="| Region | Q1 | Q2 North America | 10 | 12 Europe | 8 | 7 |\n| --- | --- | --- | --- | --- | --- | --- |", + reading_order=3, + confidence=0.88, + source_parser="docling", + ), + Element( + element_id="pymupdf_p1_e3", + doc_id="d1", + page_num=1, + type="table", + bbox=(72.0, 144.0, 237.0, 186.0), + markdown="| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North America | 10 | 12 |\n| Europe | 8 | 7 |", + reading_order=3, + confidence=0.72, + source_parser="pymupdf", + ), + ] + + deduped = dedupe_elements(elements) + + self.assertEqual(len(deduped), 1) + self.assertEqual(deduped[0].source_parser, "pymupdf") + self.assertEqual(deduped[0].confidence, 0.88) + self.assertIn("| North America | 10 | 12 |", deduped[0].markdown or "") + self.assertEqual(deduped[0].bbox, (72.0, 144.0, 237.0, 186.0)) + + def test_merges_duplicate_tables_and_keeps_better_grid_assets(self): + tables = [ + TableObject( + table_id="docling_t1", + page_nums=[1], + markdown="| Region | Q1 | Q2 North America | 10 | 12 Europe | 8 | 7 |\n| --- | --- | --- | --- | --- | --- | --- |", + confidence=0.84, + source_parser="docling", + ), + TableObject( + table_id="pymupdf_t1", + page_nums=[1], + bbox=[(72.0, 144.0, 237.0, 186.0)], + markdown="| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North America | 10 | 12 |\n| Europe | 8 | 7 |", + confidence=0.72, + source_parser="pymupdf", + provenance={"crop_path": "/tmp/table.png"}, + ), + ] + + deduped = dedupe_tables(tables) + + self.assertEqual(len(deduped), 1) + self.assertEqual(deduped[0].source_parser, "pymupdf") + self.assertEqual(deduped[0].confidence, 0.84) + self.assertEqual(deduped[0].bbox, [(72.0, 144.0, 237.0, 186.0)]) + self.assertEqual(deduped[0].provenance["crop_path"], "/tmp/table.png") + self.assertEqual(deduped[0].provenance["source_parsers"], ["pymupdf", "docling"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_parser_disagreement.py b/tests/test_parser_disagreement.py new file mode 100644 index 0000000000000000000000000000000000000000..34af84c932982790a61cfc1cf8ab548c808929b8 --- /dev/null +++ b/tests/test_parser_disagreement.py @@ -0,0 +1,177 @@ +"""Tests for parser-disagreement and repair-success metrics.""" + +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path + +from zsgdp.merge.conflict_detection import build_candidate_conflict_report +from zsgdp.pipeline import parse_document +from zsgdp.schema import DocumentProfile, Element, ParseCandidate, PageProfile, TableObject +from zsgdp.verify.parser_disagreement import compute_parser_disagreement +from zsgdp.verify.repair_success import compute_repair_success + + +def _profile() -> DocumentProfile: + return DocumentProfile( + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + page_count=1, + extension=".md", + pages=[PageProfile(page_num=1, digital_text_chars=400, digital_text_quality=0.9)], + ) + + +def _candidate(name: str, *, text: str, table_count: int = 0) -> ParseCandidate: + elements = [ + Element( + element_id=f"{name}_e1", + doc_id="d1", + page_num=1, + type="paragraph", + text=text, + reading_order=1, + source_parser=name, + ) + ] + tables: list[TableObject] = [] + for index in range(table_count): + tables.append( + TableObject( + table_id=f"{name}_t{index + 1}", + page_nums=[1], + markdown="| A | B |\n| --- | --- |\n| 1 | 2 |", + source_parser=name, + ) + ) + return ParseCandidate( + parser_name=name, + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + elements=elements, + tables=tables, + figures=[], + pages=[{"page_num": 1, "source_parser": name}], + confidence=0.8, + ) + + +class TestParserDisagreement(unittest.TestCase): + def test_disagreement_rate_uses_pair_count_denominator(self): + candidates = [ + _candidate("docling", text="A" * 800, table_count=4), + _candidate("pymupdf", text="A" * 100, table_count=0), + ] + report = build_candidate_conflict_report(candidates) + parser_metrics = { + "docling": {"parser": "docling", "failed": False}, + "pymupdf": {"parser": "pymupdf", "failed": False}, + } + + result = compute_parser_disagreement(report, parser_metrics) + + self.assertEqual(result["candidate_count"], 2) + self.assertEqual(result["parser_pair_count"], 1) + self.assertGreater(result["conflict_count"], 0) + self.assertGreater(result["disagreement_rate"], 0.0) + self.assertIn("text_coverage_gap", result["disagreement_by_type"]) + self.assertIn("docling|pymupdf", result["disagreement_by_parser_pair"]) + + def test_disagreement_rate_zero_when_single_parser(self): + result = compute_parser_disagreement( + {"conflicts": []}, + {"docling": {"parser": "docling", "failed": False}}, + ) + + self.assertEqual(result["candidate_count"], 1) + self.assertEqual(result["parser_pair_count"], 0) + self.assertEqual(result["disagreement_rate"], 0.0) + + def test_failed_parsers_excluded_from_pair_count(self): + result = compute_parser_disagreement( + {"conflicts": []}, + { + "docling": {"parser": "docling", "failed": False}, + "marker": {"parser": "marker", "failed": True, "error": "boom"}, + "pymupdf": {"parser": "pymupdf", "failed": False}, + }, + ) + + self.assertEqual(result["candidate_count"], 2) + self.assertEqual(result["parser_pair_count"], 1) + + +class TestRepairSuccess(unittest.TestCase): + def test_resolution_rate_when_blocking_issue_resolved(self): + pre = {"score": 0.5, "issues": [{"issue_type": "invalid_table", "blocking": True, "page_num": 1, "region_id": "t1"}]} + post = {"score": 0.9, "issues": []} + history = [{"iteration": 1, "before_score": 0.5, "after_score": 0.9, "actions": [{"action": "repair_table"}]}] + + result = compute_repair_success(pre, post, history) + + self.assertEqual(result["pre_repair_blocking_count"], 1) + self.assertEqual(result["post_repair_blocking_count"], 0) + self.assertEqual(result["resolved_blocking_count"], 1) + self.assertEqual(result["repair_resolution_rate"], 1.0) + self.assertEqual(result["repair_regression_rate"], 0.0) + self.assertEqual(result["iteration_count"], 1) + self.assertAlmostEqual(result["score_delta"], 0.4, places=6) + + def test_regression_rate_counts_new_blocking_issues(self): + pre = {"score": 0.7, "issues": [{"issue_type": "invalid_table", "blocking": True, "region_id": "t1"}]} + post = { + "score": 0.6, + "issues": [ + {"issue_type": "invalid_table", "blocking": True, "region_id": "t1"}, + {"issue_type": "missing_text_coverage", "blocking": True, "page_num": 2}, + ], + } + history = [{"iteration": 1, "before_score": 0.7, "after_score": 0.6, "actions": []}] + + result = compute_repair_success(pre, post, history) + + self.assertEqual(result["resolved_blocking_count"], 0) + self.assertEqual(result["regressed_blocking_count"], 1) + self.assertEqual(result["repair_regression_rate"], 1.0) + self.assertEqual(result["repair_resolution_rate"], 0.0) + + def test_vacuous_success_when_no_pre_repair_blocking_issues(self): + result = compute_repair_success( + {"score": 1.0, "issues": []}, + {"score": 1.0, "issues": []}, + [], + ) + + self.assertEqual(result["repair_resolution_rate"], 1.0) + self.assertEqual(result["repair_regression_rate"], 0.0) + self.assertEqual(result["iteration_count"], 0) + + +class TestRepairSuccessIntegration(unittest.TestCase): + def test_pipeline_records_resolution_for_iterative_table_repair(self): + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "report.md" + input_path.write_text( + "# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 | 3 |\n", + encoding="utf-8", + ) + + parsed = parse_document(input_path, Path(tmp) / "out") + + metrics = parsed.quality_report.metrics + self.assertIn("repair_resolution_rate", metrics) + self.assertIn("repair_regression_rate", metrics) + self.assertIn("parser_disagreement_rate", metrics) + + success = parsed.provenance["repair_success"] + self.assertGreaterEqual(success["pre_repair_issue_count"], 1) + self.assertGreaterEqual(success["resolved_any_count"], 1) + self.assertGreaterEqual(success["repair_resolution_rate_any"], 0.0) + self.assertGreater(success["iteration_count"], 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_parser_metrics.py b/tests/test_parser_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e2920b89b857bf7feaa3bad87caf1e2bdabbeb --- /dev/null +++ b/tests/test_parser_metrics.py @@ -0,0 +1,54 @@ +import unittest + +from zsgdp.schema import DocumentProfile, Element, PageProfile, ParseCandidate, TableObject +from zsgdp.verify.parser_metrics import candidate_metrics, failure_metrics + + +class ParserMetricsTests(unittest.TestCase): + def test_candidate_metrics_reports_coverage_and_valid_tables(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.md", + file_type="markdown", + page_count=1, + extension=".md", + pages=[PageProfile(page_num=1, digital_text_chars=11)], + ) + candidate = ParseCandidate( + parser_name="test", + doc_id="d1", + source_path="sample.md", + file_type="markdown", + pages=[{"page_num": 1}], + elements=[ + Element("e1", "d1", 1, "paragraph", text="hello world", bbox=(0, 0, 10, 10)), + ], + tables=[ + TableObject( + table_id="t1", + page_nums=[1], + markdown="| A | B |\n| --- | --- |\n| 1 | 2 |", + ) + ], + confidence=0.9, + ) + + metrics = candidate_metrics(candidate, profile, elapsed_seconds=0.25) + + self.assertEqual(metrics["parser"], "test") + self.assertEqual(metrics["text_coverage_ratio"], 1.0) + self.assertEqual(metrics["valid_table_ratio"], 1.0) + self.assertTrue(metrics["has_bboxes"]) + + def test_failure_metrics_records_error(self): + profile = DocumentProfile("d1", "sample.pdf", "pdf", 1, ".pdf") + + metrics = failure_metrics("docling", profile, "boom", elapsed_seconds=1.5) + + self.assertTrue(metrics["failed"]) + self.assertEqual(metrics["error"], "boom") + self.assertEqual(metrics["elapsed_seconds"], 1.5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pdf_integration.py b/tests/test_pdf_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..769c5ad8e3dc84ad915acecc07ed61ba51e9bcaa --- /dev/null +++ b/tests/test_pdf_integration.py @@ -0,0 +1,44 @@ +import importlib.util +import tempfile +import unittest +from pathlib import Path + +from zsgdp.pipeline import parse_document + + +@unittest.skipIf(importlib.util.find_spec("fitz") is None, "PyMuPDF is not installed") +class PDFIntegrationTests(unittest.TestCase): + def test_pymupdf_parse_exports_page_table_and_figure_assets(self): + import fitz # type: ignore + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + pdf_path = tmp_path / "sample.pdf" + output_dir = tmp_path / "out" + doc = fitz.open() + page = doc.new_page(width=612, height=792) + page.insert_text((72, 72), "Annual Report", fontsize=20) + page.insert_text((72, 120), "Revenue Summary", fontsize=14) + page.insert_text( + (72, 155), + "Region Q1 Q2\nNorth America 10 12\nEurope 8 7", + fontsize=11, + fontname="cour", + ) + page.draw_rect(fitz.Rect(72, 265, 260, 360)) + doc.save(pdf_path) + + parsed = parse_document(pdf_path, output_dir, selected_parsers=["pymupdf"]) + + self.assertEqual(parsed.file_type, "pdf") + self.assertEqual(len(parsed.tables), 1) + self.assertGreaterEqual(len(parsed.figures), 1) + self.assertTrue((output_dir / "assets" / "pages" / "page_0001.png").exists()) + self.assertTrue((output_dir / "assets" / "tables" / "p0001_t001.png").exists()) + self.assertTrue(any((output_dir / "assets" / "figures").glob("p0001_f*.png"))) + self.assertEqual(parsed.quality_report.metrics["table_chunk_coverage"], 1.0) + self.assertEqual(parsed.quality_report.metrics["figure_chunk_coverage"], 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_per_parser_leaderboard.py b/tests/test_per_parser_leaderboard.py new file mode 100644 index 0000000000000000000000000000000000000000..847e8c59de678a419e065eedad08f993aef2c609 --- /dev/null +++ b/tests/test_per_parser_leaderboard.py @@ -0,0 +1,150 @@ +"""Tests for the per-parser GT-comparison leaderboard rollup.""" + +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path + +from zsgdp.benchmarks.datasets import DatasetDocument, register_dataset_loader +from zsgdp.benchmarks.datasets import _LOADERS as _DATASET_LOADERS +from zsgdp.benchmarks.parser_quality import _per_parser_gt_leaderboard, run_parser_benchmark + + +class TestPerParserGTLeaderboard(unittest.TestCase): + def test_aggregates_by_parser(self): + rows = [ + { + "parser": "docling", + "layout_evaluated": True, + "table_evaluated": True, + "formula_evaluated": False, + "layout_class_aware_f1": 0.9, + "layout_class_agnostic_f1": 0.95, + "layout_class_aware_precision": 0.9, + "layout_class_aware_recall": 0.9, + "layout_prediction_count": 10, + "table_structure_score": 0.8, + "table_match_rate": 0.9, + "table_cell_content_f1": 0.7, + "formula_cer": 0.0, + "formula_accuracy": 0.0, + "formula_exact_match_rate": 0.0, + "element_count": 10, + "table_count": 2, + "figure_count": 1, + }, + { + "parser": "docling", + "layout_evaluated": True, + "table_evaluated": False, + "formula_evaluated": False, + "layout_class_aware_f1": 0.7, + "layout_class_agnostic_f1": 0.75, + "layout_class_aware_precision": 0.7, + "layout_class_aware_recall": 0.7, + "layout_prediction_count": 8, + "table_structure_score": 0.0, + "table_match_rate": 0.0, + "table_cell_content_f1": 0.0, + "formula_cer": 0.0, + "formula_accuracy": 0.0, + "formula_exact_match_rate": 0.0, + "element_count": 8, + "table_count": 0, + "figure_count": 0, + }, + { + "parser": "pymupdf", + "layout_evaluated": True, + "table_evaluated": False, + "formula_evaluated": False, + "layout_class_aware_f1": 0.5, + "layout_class_agnostic_f1": 0.55, + "layout_class_aware_precision": 0.5, + "layout_class_aware_recall": 0.5, + "layout_prediction_count": 6, + "table_structure_score": 0.0, + "table_match_rate": 0.0, + "table_cell_content_f1": 0.0, + "formula_cer": 0.0, + "formula_accuracy": 0.0, + "formula_exact_match_rate": 0.0, + "element_count": 6, + "table_count": 0, + "figure_count": 0, + }, + ] + + leaderboard = _per_parser_gt_leaderboard(rows) + by_parser = {row["parser"]: row for row in leaderboard} + + # Docling appears once with 2 documents, 2 layout-evaluated, 1 table-evaluated. + self.assertEqual(by_parser["docling"]["document_count"], 2) + self.assertEqual(by_parser["docling"]["layout_evaluated_count"], 2) + self.assertEqual(by_parser["docling"]["table_evaluated_count"], 1) + self.assertEqual(by_parser["docling"]["formula_evaluated_count"], 0) + self.assertAlmostEqual(by_parser["docling"]["mean_layout_class_aware_f1"], 0.8, places=6) + # Table mean uses only the row that was evaluated. + self.assertAlmostEqual(by_parser["docling"]["mean_table_structure_score"], 0.8, places=6) + + # PyMuPDF appears once. + self.assertEqual(by_parser["pymupdf"]["document_count"], 1) + self.assertAlmostEqual(by_parser["pymupdf"]["mean_layout_class_aware_f1"], 0.5, places=6) + + def test_sorted_by_layout_then_table_then_formula_inverse(self): + rows = [ + {"parser": "low", "layout_evaluated": True, "layout_class_aware_f1": 0.3, "layout_class_agnostic_f1": 0.3}, + {"parser": "high", "layout_evaluated": True, "layout_class_aware_f1": 0.9, "layout_class_agnostic_f1": 0.9}, + ] + leaderboard = _per_parser_gt_leaderboard(rows) + self.assertEqual([row["parser"] for row in leaderboard], ["high", "low"]) + + def test_empty_rows_returns_empty_list(self): + self.assertEqual(_per_parser_gt_leaderboard([]), []) + + +class TestBenchmarkEmitsLeaderboard(unittest.TestCase): + def test_per_parser_gt_leaderboard_csv_written(self): + ground_truth = { + "layout_dets": [ + {"category": "title", "bbox": [0, 0, 100, 30], "page_num": 1}, + ] + } + + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + md_path = src / "doc.md" + md_path.write_text("# Doc\n\nSome text.\n", encoding="utf-8") + + def fake_loader(root: Path): + yield DatasetDocument( + dataset_id="omnidocbench", + doc_id="doc", + path=md_path, + ground_truth=ground_truth, + metadata={}, + ) + + register_dataset_loader("omnidocbench", fake_loader) + try: + summary = run_parser_benchmark(src, tmp / "out", dataset_name="omnidocbench") + finally: + from zsgdp.benchmarks.datasets import _load_omnidocbench + + _DATASET_LOADERS["omnidocbench"] = _load_omnidocbench + + self.assertIn("per_parser_gt_leaderboard", summary) + csv_path = tmp / "out" / "per_parser_gt_leaderboard.csv" + self.assertTrue(csv_path.exists()) + content = csv_path.read_text() + header = content.splitlines()[0] + self.assertIn("parser", header) + self.assertIn("mean_layout_class_aware_f1", header) + self.assertIn("layout_evaluated_count", header) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_per_parser_metrics.py b/tests/test_per_parser_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..34c217c2d4fd6bdee641cc8fef4ca49ca5b563ae --- /dev/null +++ b/tests/test_per_parser_metrics.py @@ -0,0 +1,166 @@ +"""Tests for per-parser GT-comparison metrics within a single merged run.""" + +from __future__ import annotations + +import json +import tempfile +import unittest +from pathlib import Path +from types import SimpleNamespace + +from zsgdp.benchmarks.datasets import DatasetDocument, register_dataset_loader +from zsgdp.benchmarks.datasets import _LOADERS as _DATASET_LOADERS +from zsgdp.benchmarks.parser_quality import run_parser_benchmark +from zsgdp.benchmarks.per_parser_metrics import compute_per_parser_metrics + + +def _parsed_with_candidates(candidates: dict) -> SimpleNamespace: + return SimpleNamespace(provenance={"candidates": candidates}) + + +class TestComputePerParserMetrics(unittest.TestCase): + def test_returns_one_block_per_parser_with_layout_truths(self): + candidates = { + "docling": { + "elements": [ + {"element_id": "e1", "type": "title", "page_num": 1, "bbox": [0, 0, 100, 30]}, + ], + "tables": [], + "figures": [], + }, + "pymupdf": { + "elements": [ + {"element_id": "e2", "type": "paragraph", "page_num": 1, "bbox": [200, 200, 300, 300]}, + ], + "tables": [], + "figures": [], + }, + } + layout_truths = [{"bbox": (0, 0, 100, 30), "category": "title", "page_num": 1}] + + result = compute_per_parser_metrics( + _parsed_with_candidates(candidates), + layout_truths=layout_truths, + ) + + self.assertEqual(set(result), {"docling", "pymupdf"}) + self.assertEqual(result["docling"]["layout"]["class_aware_f1"], 1.0) + # PyMuPDF predicted a paragraph far from any truth -> 0 F1. + self.assertEqual(result["pymupdf"]["layout"]["class_aware_f1"], 0.0) + # Element counts surfaced even when the parser scored zero. + self.assertEqual(result["pymupdf"]["element_count"], 1) + + def test_omits_metric_block_when_truths_empty(self): + candidates = { + "docling": { + "elements": [{"element_id": "e1", "type": "title", "page_num": 1, "bbox": [0, 0, 10, 10]}], + "tables": [], + "figures": [], + }, + } + result = compute_per_parser_metrics(_parsed_with_candidates(candidates)) + self.assertEqual(set(result["docling"]), {"parser", "element_count", "table_count", "figure_count"}) + + def test_table_and_formula_metrics_per_parser(self): + candidates = { + "docling": { + "elements": [ + {"element_id": "f1", "type": "formula", "page_num": 1, "text": "E = mc^2"}, + ], + "tables": [ + {"table_id": "t1", "page_nums": [1], "markdown": "| A | B |\n| --- | --- |\n| 1 | 2 |"}, + ], + "figures": [], + }, + "pymupdf": { + "elements": [ + {"element_id": "f2", "type": "formula", "page_num": 1, "text": "E = mc^9"}, + ], + "tables": [], + "figures": [], + }, + } + table_truths = [{"markdown": "| A | B |\n| --- | --- |\n| 1 | 2 |", "page_num": 1}] + formula_truths = [{"latex": "E = mc^2", "page_num": 1}] + + result = compute_per_parser_metrics( + _parsed_with_candidates(candidates), + table_truths=table_truths, + formula_truths=formula_truths, + ) + + # Docling matches table and formula exactly. + self.assertEqual(result["docling"]["table_structure"]["mean_table_score"], 1.0) + self.assertEqual(result["docling"]["formula"]["mean_cer"], 0.0) + # PyMuPDF's formula is one char off; table predictions empty. + self.assertGreater(result["pymupdf"]["formula"]["mean_cer"], 0.0) + self.assertEqual(result["pymupdf"]["table_structure"]["matched_pair_count"], 0) + + def test_no_candidates_returns_empty_dict(self): + parsed = SimpleNamespace(provenance={"candidates": {}}) + self.assertEqual(compute_per_parser_metrics(parsed, layout_truths=[]), {}) + + +class TestPipelinePopulatesCandidates(unittest.TestCase): + def test_candidates_serialized_to_provenance(self): + with tempfile.TemporaryDirectory() as tmp: + input_path = Path(tmp) / "doc.md" + input_path.write_text("# Doc\n\nSome content.\n", encoding="utf-8") + + from zsgdp.pipeline import parse_document + + parsed = parse_document(input_path, Path(tmp) / "out") + + candidates = parsed.provenance.get("candidates") + self.assertIsInstance(candidates, dict) + self.assertGreater(len(candidates), 0) + # text parser should be one of the candidates for markdown. + self.assertIn("text", candidates) + self.assertIn("elements", candidates["text"]) + + +class TestBenchmarkIntegration(unittest.TestCase): + def test_per_parser_csv_emitted_with_omnidocbench_truths(self): + ground_truth = { + "layout_dets": [ + {"category": "title", "bbox": [0, 0, 100, 30], "page_num": 1}, + {"category": "table", "markdown": "| A | B |\n| --- | --- |\n| 1 | 2 |", "page_num": 1}, + ] + } + + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + md_path = src / "doc.md" + md_path.write_text("# Doc\n\n| A | B |\n| --- | --- |\n| 1 | 2 |\n", encoding="utf-8") + + def fake_loader(root: Path): + yield DatasetDocument( + dataset_id="omnidocbench", + doc_id="doc", + path=md_path, + ground_truth=ground_truth, + metadata={}, + ) + + register_dataset_loader("omnidocbench", fake_loader) + try: + summary = run_parser_benchmark(src, tmp / "out", dataset_name="omnidocbench") + finally: + from zsgdp.benchmarks.datasets import _load_omnidocbench + + _DATASET_LOADERS["omnidocbench"] = _load_omnidocbench + + doc = summary["documents"][0] + self.assertIn("per_parser_metrics", doc) + self.assertGreater(len(doc["per_parser_metrics"]), 0) + csv_path = tmp / "out" / "per_parser_metrics.csv" + self.assertTrue(csv_path.exists()) + content = csv_path.read_text() + self.assertIn("parser", content.splitlines()[0]) + self.assertGreater(len(content.splitlines()), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5c71eefe55d653082d065c00a4981e4f7f8e3a3d --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,128 @@ +import tempfile +import unittest +from pathlib import Path + +from zsgdp.cli import main +from zsgdp.pipeline import parse_document + + +class PipelineTests(unittest.TestCase): + def test_parse_document_writes_outputs(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "sample.md" + output_dir = tmp_path / "out" + input_path.write_text("# Report\n\nThis is a test document.\n", encoding="utf-8") + + parsed = parse_document(input_path, output_dir) + + self.assertTrue(parsed.elements) + self.assertTrue(parsed.chunks) + self.assertEqual(parsed.provenance["config_deployment"]["gpu_models_target"], "zeroshotGPU") + self.assertEqual(parsed.provenance["gpu_runtime"]["gpu_models_target"], "zeroshotGPU") + self.assertIn("chunking", parsed.provenance) + self.assertTrue((output_dir / "parsed_document.json").exists()) + self.assertTrue((output_dir / "elements.jsonl").exists()) + self.assertTrue((output_dir / "chunks.jsonl").exists()) + self.assertTrue((output_dir / "chunking_plan.json").exists()) + self.assertTrue((output_dir / "parser_metrics.json").exists()) + self.assertTrue((output_dir / "gpu_runtime.json").exists()) + self.assertTrue((output_dir / "artifact_manifest.json").exists()) + + def test_parse_document_exports_gpu_tasks_when_visual_work_exists(self): + try: + import fitz # type: ignore + except ImportError: + self.skipTest("PyMuPDF is not installed") + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "visual.pdf" + output_dir = tmp_path / "out" + doc = fitz.open() + page = doc.new_page() + page.insert_text((72, 72), "Report") + page.draw_rect(fitz.Rect(72, 120, 180, 180)) + doc.save(input_path) + doc.close() + + parsed = parse_document(input_path, output_dir, config_overrides={"chunking": {"vision_guided": True}}) + + self.assertTrue(parsed.provenance["gpu_tasks"]) + self.assertEqual(parsed.provenance["gpu_task_report"]["task_count"], len(parsed.provenance["gpu_tasks"])) + self.assertTrue((output_dir / "gpu_tasks.jsonl").exists()) + self.assertTrue((output_dir / "gpu_task_report.json").exists()) + + def test_parse_document_reverifies_after_repair(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "table.md" + output_dir = tmp_path / "out" + input_path.write_text("# Report\n\n| A | B |\n| --- | --- |\n| 1 | 2 | 3 |\n", encoding="utf-8") + + parsed = parse_document(input_path, output_dir) + + issue_types = [issue.issue_type for issue in parsed.quality_report.issues] + self.assertNotIn("invalid_table", issue_types) + self.assertIn("pre_repair_quality", parsed.provenance) + self.assertTrue(parsed.provenance["repair_iterations"]) + self.assertIn("invalid_table", [issue["issue_type"] for issue in parsed.provenance["pre_repair_quality"]["issues"]]) + + def test_export_chunks_cli(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + input_path = tmp_path / "sample.md" + output_dir = tmp_path / "out" + chunks_out = tmp_path / "chunks.json" + input_path.write_text("# Report\n\nThis is a test document.\n", encoding="utf-8") + parse_document(input_path, output_dir) + + exit_code = main([ + "export-chunks", + "--parsed", + str(output_dir), + "--format", + "json", + "--output", + str(chunks_out), + ]) + + self.assertEqual(exit_code, 0) + self.assertTrue(chunks_out.exists()) + self.assertIn('"chunk_id"', chunks_out.read_text(encoding="utf-8")) + + def test_parse_folder_cli_uses_workers_and_unique_output_dirs(self): + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + docs = tmp_path / "docs" + output_dir = tmp_path / "out" + docs.mkdir() + (docs / "same.md").write_text("# Markdown\n\nHello world.\n", encoding="utf-8") + (docs / "same.txt").write_text("Plain text document.\n", encoding="utf-8") + + code = main([ + "parse-folder", + "--input", + str(docs), + "--output", + str(output_dir), + "--workers", + "2", + "--gpu-workers", + "1", + "--parsers", + "text", + ]) + + self.assertEqual(code, 0) + self.assertTrue((output_dir / "same" / "artifact_manifest.json").exists()) + self.assertTrue((output_dir / "same-txt" / "artifact_manifest.json").exists()) + + def test_gpu_status_cli(self): + code = main(["gpu-status"]) + + self.assertEqual(code, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_preflight.py b/tests/test_preflight.py new file mode 100644 index 0000000000000000000000000000000000000000..2bafa4971d08c800a8210d5cddf559901c31f8c7 --- /dev/null +++ b/tests/test_preflight.py @@ -0,0 +1,112 @@ +"""Tests for the preflight runner.""" + +from __future__ import annotations + +import unittest +from unittest.mock import patch + +from zsgdp.preflight import StepResult, format_failures, format_summary, run_preflight + + +class FormatSummaryTests(unittest.TestCase): + def test_pass_summary_lists_each_step(self): + from zsgdp.preflight import PreflightResult + + result = PreflightResult( + steps=[ + StepResult(name="unit", passed=True, elapsed_seconds=0.5), + StepResult(name="regression", passed=True, elapsed_seconds=0.7), + ] + ) + summary = format_summary(result) + self.assertIn("[ok] unit", summary) + self.assertIn("[ok] regression", summary) + self.assertIn("preflight: PASS", summary) + + def test_fail_summary_marks_failed_step(self): + from zsgdp.preflight import PreflightResult + + result = PreflightResult( + steps=[ + StepResult(name="unit", passed=True, elapsed_seconds=0.5), + StepResult(name="regression", passed=False, elapsed_seconds=1.2, output="snapshot drift"), + ] + ) + summary = format_summary(result) + self.assertIn("[FAIL] regression", summary) + self.assertIn("preflight: FAIL", summary) + + def test_skipped_steps_render_as_skip(self): + from zsgdp.preflight import PreflightResult + + result = PreflightResult( + steps=[StepResult(name="unit", passed=True, elapsed_seconds=0.0, skipped=True)] + ) + self.assertIn("[skip] unit", format_summary(result)) + + def test_format_failures_concatenates_outputs(self): + from zsgdp.preflight import PreflightResult + + result = PreflightResult( + steps=[ + StepResult(name="unit", passed=False, elapsed_seconds=0.0, output="boom"), + StepResult(name="regression", passed=False, elapsed_seconds=0.0, output="snapshot off"), + ] + ) + text = format_failures(result) + self.assertIn("--- unit output ---", text) + self.assertIn("boom", text) + self.assertIn("--- regression output ---", text) + self.assertIn("snapshot off", text) + + +class RunPreflightTests(unittest.TestCase): + def test_skip_flags_mark_steps_skipped(self): + with patch("zsgdp.preflight._run_step") as run_step: + run_step.return_value = StepResult(name="unit", passed=True, elapsed_seconds=0.0) + result = run_preflight( + skip_unit=True, + skip_regression=True, + skip_space_check=True, + skip_parsers=True, + ) + + self.assertEqual(run_step.call_count, 0) + self.assertTrue(all(step.skipped for step in result.steps)) + self.assertTrue(result.passed) + + def test_aggregates_pass_when_all_steps_succeed(self): + with patch( + "zsgdp.preflight._run_step", + side_effect=lambda name, command, cwd: StepResult(name=name, passed=True, elapsed_seconds=0.1), + ): + result = run_preflight() + + self.assertTrue(result.passed) + # No benchmark by default => 4 steps. + self.assertEqual(len(result.steps), 4) + + def test_failure_in_one_step_fails_overall(self): + def _step(name, command, cwd): + return StepResult(name=name, passed=(name != "regression"), elapsed_seconds=0.1, output="boom" if name == "regression" else "") + + with patch("zsgdp.preflight._run_step", side_effect=_step): + result = run_preflight() + + self.assertFalse(result.passed) + failed_names = [step.name for step in result.failed_steps] + self.assertEqual(failed_names, ["regression"]) + + def test_benchmark_step_added_when_enabled(self): + with patch( + "zsgdp.preflight._run_step", + side_effect=lambda name, command, cwd: StepResult(name=name, passed=True, elapsed_seconds=0.1), + ): + result = run_preflight(run_benchmark=True) + + names = [step.name for step in result.steps] + self.assertEqual(names, ["unit", "regression", "space_check", "parsers", "benchmark"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_profiler.py b/tests/test_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce48b3868c5092a850c295986afed398c98dcf9 --- /dev/null +++ b/tests/test_profiler.py @@ -0,0 +1,22 @@ +import tempfile +import unittest +from pathlib import Path + +from zsgdp.profiling import profile_document + + +class ProfilerTests(unittest.TestCase): + def test_text_profile_detects_table(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "sample.md" + path.write_text("# Revenue\n\n| Region | Q1 |\n| --- | --- |\n| NA | 10 |\n", encoding="utf-8") + + profile = profile_document(path) + + self.assertEqual(profile.file_type, "markdown") + self.assertEqual(profile.page_count, 1) + self.assertGreaterEqual(profile.pages[0].table_candidate_count, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pymupdf_parser.py b/tests/test_pymupdf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4dda73e046514db24476c45633f7e2e108225e --- /dev/null +++ b/tests/test_pymupdf_parser.py @@ -0,0 +1,52 @@ +import unittest +from pathlib import Path + +from zsgdp.parsers.pymupdf_parser import ( + TextBlock, + _asset_root, + _guess_element_type, + _is_table_text, + _sort_blocks_reading_order, + _table_text_to_markdown, +) + + +class PyMuPDFParserHelperTests(unittest.TestCase): + def test_table_text_detection_and_markdown(self): + text = "Region Q1 Q2\nNorth America 10 12\nEurope 8 7" + + self.assertTrue(_is_table_text(text)) + self.assertEqual( + _table_text_to_markdown(text), + "| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North America | 10 | 12 |\n| Europe | 8 | 7 |", + ) + + def test_reading_order_detects_two_columns(self): + blocks = [ + TextBlock(1, "right top", (320, 50, 500, 70), 1), + TextBlock(1, "left bottom", (50, 200, 230, 220), 2), + TextBlock(1, "left top", (50, 50, 230, 70), 3), + TextBlock(1, "right bottom", (320, 200, 500, 220), 4), + TextBlock(1, "left mid", (50, 120, 230, 140), 5), + TextBlock(1, "right mid", (320, 120, 500, 140), 6), + ] + + ordered = _sort_blocks_reading_order(blocks, 600) + + self.assertEqual([block.text for block in ordered], ["left top", "left mid", "left bottom", "right top", "right mid", "right bottom"]) + + def test_guess_element_type_for_table_and_title(self): + table = TextBlock(1, "A B\n1 2", (50, 100, 300, 160), 1, avg_font_size=10) + title = TextBlock(1, "Annual Report", (50, 40, 400, 70), 2, max_font_size=18, avg_font_size=18) + + self.assertEqual(_guess_element_type(table, 2, 10, 800), "table") + self.assertEqual(_guess_element_type(title, 1, 10, 800), "title") + + def test_asset_root_uses_runtime_output_dir(self): + root = _asset_root({"runtime": {"output_dir": "/tmp/out"}, "pdf": {"asset_dir": "pdf_assets"}}) + + self.assertEqual(root, Path("/tmp/out/pdf_assets")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_repair.py b/tests/test_repair.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b3d0a3b2c010c84f8efa155892897ceb4e5340 --- /dev/null +++ b/tests/test_repair.py @@ -0,0 +1,164 @@ +import unittest +import tempfile +from pathlib import Path +from unittest.mock import patch + +from zsgdp.repair import run_repair_loop +from zsgdp.schema import DocumentProfile, Element, FigureObject, PageProfile, ParsedDocument, QualityReport, TableObject + + +class RepairTests(unittest.TestCase): + def test_repair_adds_table_rendering_and_syncs_element(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf", quality_report=QualityReport()) + parsed.elements.append( + Element( + element_id="e1", + doc_id="d1", + page_num=1, + type="table", + markdown="Region Q1 Q2\nNorth America 10 12", + source_parser="docling", + ) + ) + parsed.tables.append( + TableObject( + table_id="t1", + page_nums=[1], + markdown="Region Q1 Q2\nNorth America 10 12", + provenance={"element_id": "e1"}, + ) + ) + + repaired = run_repair_loop(profile, parsed, {"repair": {"enabled": True}}) + + self.assertEqual( + repaired.tables[0].markdown, + "| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North America | 10 | 12 |", + ) + self.assertIn("Table with columns Region, Q1, Q2.", repaired.tables[0].natural_language_rendering or "") + self.assertEqual(repaired.elements[0].markdown, repaired.tables[0].markdown) + self.assertEqual(repaired.provenance["repair"]["status"], "executed_deterministic") + + def test_repair_attaches_figure_caption(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf", quality_report=QualityReport()) + parsed.elements.append( + Element( + element_id="caption1", + doc_id="d1", + page_num=1, + type="caption", + text="Figure 1 Revenue trend", + reading_order=4, + ) + ) + parsed.figures.append(FigureObject(figure_id="f1", page_num=1)) + + repaired = run_repair_loop(profile, parsed, {"repair": {"enabled": True}}) + + self.assertEqual(repaired.figures[0].caption, "Figure 1 Revenue trend") + self.assertEqual(repaired.figures[0].provenance["caption_source_element_id"], "caption1") + + def test_repair_marks_repeated_headers(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=3, + extension=".pdf", + pages=[PageProfile(page_num=1), PageProfile(page_num=2), PageProfile(page_num=3)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf", quality_report=QualityReport()) + for page_num in [1, 2, 3]: + parsed.elements.append( + Element( + element_id=f"h{page_num}", + doc_id="d1", + page_num=page_num, + type="header", + text="Quarterly Report", + ) + ) + + repaired = run_repair_loop(profile, parsed, {"repair": {"enabled": True}}) + + self.assertEqual(repaired.provenance["repair"]["actions"][0]["element_count"], 3) + self.assertTrue(all(element.provenance["noise_candidate"] == "repeated_header_footer" for element in repaired.elements)) + + def test_gpu_table_repair_executes_and_applies_output(self): + with tempfile.TemporaryDirectory() as tmp: + crop_path = Path(tmp) / "table.png" + crop_path.write_bytes(b"fake image") + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf", quality_report=QualityReport()) + parsed.tables.append( + TableObject( + table_id="t1", + page_nums=[1], + markdown="Region Q1 Q2\nNorth 10 12", + provenance={"crop_path": str(crop_path)}, + ) + ) + parsed.quality_report.add_issue( + "invalid_table", + "warning", + "Table t1 is not valid.", + page_num=1, + region_id="t1", + ) + + with patch("zsgdp.repair.controller.GPUWorker.run") as run: + run.return_value = { + "task_id": "gr1", + "task_type": "table_vlm_repair", + "region_id": "t1", + "status": "executed", + "readiness": {"ready": True}, + "output": { + "status": "executed", + "text": "| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North | 10 | 14 |", + }, + } + + repaired = run_repair_loop( + profile, + parsed, + {"repair": {"enabled": True, "execute_gpu_escalations": True}}, + ) + + run.assert_called_once() + self.assertFalse(run.call_args.kwargs["dry_run"]) + self.assertEqual( + repaired.tables[0].markdown, + "| Region | Q1 | Q2 |\n| --- | --- | --- |\n| North | 10 | 14 |", + ) + self.assertTrue( + any(action["action"] == "apply_gpu_table_repair" for action in repaired.provenance["repair"]["actions"]) + ) + self.assertEqual(repaired.provenance["repair"]["gpu_escalation"]["executed_count"], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_repo_hygiene.py b/tests/test_repo_hygiene.py new file mode 100644 index 0000000000000000000000000000000000000000..072edd162671d3d745625120b68ec4ca8096ed27 --- /dev/null +++ b/tests/test_repo_hygiene.py @@ -0,0 +1,82 @@ +"""Repo-hygiene tests: gitignore, pre-commit config, .env safety, CHANGELOG. + +These exist to catch regressions in the deployment-readiness layer that +would otherwise be silent. They run as part of the standard suite, so +preflight catches them too. +""" + +from __future__ import annotations + +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] + + +class GitIgnoreTests(unittest.TestCase): + def test_dotenv_is_ignored(self): + gitignore = (ROOT / ".gitignore").read_text(encoding="utf-8") + self.assertIn(".env", gitignore) + # Whitelisted: the example template must remain committable. + self.assertIn("!.env.example", gitignore) + + def test_dotenv_example_exists_and_is_redacted(self): + import re + + env_example = (ROOT / ".env.example").read_text(encoding="utf-8") + # Template must declare HF_TOKEN as a placeholder, not contain a real one. + self.assertIn("HF_TOKEN", env_example) + # Real-shape secret patterns: prefix + long alphanumeric. The prefix + # alone is fine (it appears in the variable name and docstrings). + secret_patterns = ( + r"\bhf_[A-Za-z0-9]{20,}", + r"\bsk-[A-Za-z0-9]{20,}", + r"\bghp_[A-Za-z0-9]{20,}", + ) + for pattern in secret_patterns: + match = re.search(pattern, env_example) + self.assertIsNone(match, f"Possible secret in .env.example: {pattern}") + + +class PreCommitConfigTests(unittest.TestCase): + def test_pre_commit_config_exists_and_is_yaml(self): + config_path = ROOT / ".pre-commit-config.yaml" + self.assertTrue(config_path.exists()) + text = config_path.read_text(encoding="utf-8") + # Hand-parse without depending on PyYAML at test time. We just + # check the structural anchors that the hook config must contain. + self.assertIn("repos:", text) + self.assertIn("zsgdp-preflight", text) + self.assertIn("pre-push", text) + self.assertIn("python -m zsgdp.cli preflight", text) + + def test_pre_commit_uses_pinned_repo_revision(self): + config_path = ROOT / ".pre-commit-config.yaml" + text = config_path.read_text(encoding="utf-8") + # Every external repo must be pinned to a specific rev, not a tag like `master`. + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("rev:"): + rev = stripped.split(":", 1)[1].strip().strip("'\"") + self.assertNotIn(rev.lower(), {"head", "master", "main"}) + # Must look like a version (e.g. v5.0.0) or a commit sha. + self.assertGreaterEqual(len(rev), 3) + + +class ContributingTests(unittest.TestCase): + def test_contributing_references_preflight_and_smoke(self): + text = (ROOT / "CONTRIBUTING.md").read_text(encoding="utf-8") + self.assertIn("make preflight", text) + self.assertIn("docs/space_smoke.md", text) + self.assertIn("SCHEMA_VERSION", text) + self.assertIn("HF_TOKEN", text) + + +class ChangelogTests(unittest.TestCase): + def test_changelog_exists_and_has_unreleased_section(self): + text = (ROOT / "CHANGELOG.md").read_text(encoding="utf-8") + self.assertIn("## [Unreleased]", text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca9c0c47bcecb5f3a74b5c4715c01aff4289838 --- /dev/null +++ b/tests/test_retrieval.py @@ -0,0 +1,198 @@ +"""Tests for retrieval metrics, lexical retriever, synthetic QA, and benchmark wiring.""" + +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path + +from zsgdp.benchmarks.parser_quality import run_parser_benchmark +from zsgdp.benchmarks.retrieval import ( + LexicalRetriever, + RetrievalQuery, + run_retrieval_for_document, + synthesize_qa_set, +) +from zsgdp.schema import Chunk, ParsedDocument, QualityReport +from zsgdp.verify.retrieval import compute_retrieval_metrics + + +def _chunk(chunk_id: str, text: str, *, page: int = 1) -> Chunk: + return Chunk( + chunk_id=chunk_id, + doc_id="d1", + page_start=page, + page_end=page, + section_path=[], + content_type="prose", + text=text, + token_count=len(text.split()), + ) + + +class TestComputeRetrievalMetrics(unittest.TestCase): + def test_perfect_retrieval(self): + result = compute_retrieval_metrics([ + (["c1"], ["c1"]), + (["c2"], ["c2"]), + ]) + self.assertEqual(result["recall_at_k"][1], 1.0) + self.assertEqual(result["mean_reciprocal_rank"], 1.0) + + def test_truth_at_rank_three_yields_partial(self): + result = compute_retrieval_metrics([(["x", "y", "c1", "z"], ["c1"])]) + self.assertEqual(result["recall_at_k"][1], 0.0) + self.assertEqual(result["recall_at_k"][3], 1.0) + self.assertAlmostEqual(result["mean_reciprocal_rank"], 1 / 3) + + def test_no_hit_yields_zero_mrr(self): + result = compute_retrieval_metrics([(["x", "y"], ["c1"])]) + self.assertEqual(result["mean_reciprocal_rank"], 0.0) + self.assertEqual(result["recall_at_k"][5], 0.0) + + def test_citation_accuracy_mirrors_recall(self): + result = compute_retrieval_metrics([(["c1"], ["c1"])]) + self.assertEqual(result["citation_accuracy_at_k"][1], result["recall_at_k"][1]) + + def test_empty_queries_are_vacuous(self): + result = compute_retrieval_metrics([]) + self.assertEqual(result["query_count"], 0) + self.assertEqual(result["mean_reciprocal_rank"], 1.0) + + def test_empty_truth_sets_skipped(self): + result = compute_retrieval_metrics([(["c1"], [])]) + self.assertEqual(result["query_count"], 0) + + +class TestLexicalRetriever(unittest.TestCase): + def test_finds_distinctive_chunk(self): + chunks = [ + _chunk("c1", "Apples grow on trees in the orchard."), + _chunk("c2", "Cars drive on highways across the country."), + _chunk("c3", "Boats sail on rivers and oceans."), + ] + retriever = LexicalRetriever() + retriever.index(chunks) + + ranking = retriever.query("apples orchard", top_k=3) + self.assertEqual(ranking[0], "c1") + + def test_query_text_with_no_indexed_terms_returns_empty(self): + retriever = LexicalRetriever() + retriever.index([_chunk("c1", "Apples grow on trees.")]) + self.assertEqual(retriever.query("zzz qqq", top_k=3), []) + + def test_empty_index_returns_empty(self): + self.assertEqual(LexicalRetriever().query("anything", top_k=3), []) + + +class TestSynthesizeQASet(unittest.TestCase): + def test_picks_distinctive_sentence_per_chunk(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + chunks=[ + _chunk("c1", "Apples grow on trees in the orchard. Common shared sentence."), + _chunk("c2", "Cars drive on highways across the country. Common shared sentence."), + ], + quality_report=QualityReport(), + ) + queries = synthesize_qa_set(parsed, min_sentence_tokens=3) + truths = sorted(query.truths[0] for query in queries) + self.assertEqual(truths, ["c1", "c2"]) + for query in queries: + self.assertNotIn("Common shared", query.text) + + def test_skips_chunks_with_no_distinctive_sentences(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + chunks=[ + _chunk("c1", "Same sentence here."), + _chunk("c2", "Same sentence here."), + ], + quality_report=QualityReport(), + ) + queries = synthesize_qa_set(parsed, min_sentence_tokens=2) + self.assertEqual(queries, []) + + +class TestRunRetrievalForDocument(unittest.TestCase): + def test_end_to_end_retrieval_on_synthetic_doc(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + chunks=[ + _chunk("c1", "Apples grow on trees in the orchard during autumn season."), + _chunk("c2", "Submarines navigate beneath the ocean using sonar pulses."), + _chunk("c3", "Mountains rise above the clouds in the distant horizon."), + ], + quality_report=QualityReport(), + ) + run = run_retrieval_for_document(parsed, top_k=3) + self.assertTrue(run["evaluated"]) + self.assertEqual(run["query_count"], 3) + for result in run["results"]: + truth = result["truths"][0] + # Verbatim retrieval should put the source chunk at rank 1. + self.assertEqual(result["retrieved"][0], truth) + + def test_no_chunks_returns_unevaluated(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + chunks=[], + quality_report=QualityReport(), + ) + run = run_retrieval_for_document(parsed) + self.assertFalse(run["evaluated"]) + self.assertEqual(run["reason"], "no_chunks") + + def test_explicit_queries_override_synthesis(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.md", + file_type="markdown", + chunks=[ + _chunk("c1", "Apples grow on trees in the orchard."), + _chunk("c2", "Cars drive on highways."), + ], + quality_report=QualityReport(), + ) + queries = [RetrievalQuery(query_id="q1", text="apples orchard", truths=["c1"])] + run = run_retrieval_for_document(parsed, queries=queries) + self.assertEqual(run["query_count"], 1) + self.assertEqual(run["results"][0]["retrieved"][0], "c1") + + +class TestBenchmarkIntegration(unittest.TestCase): + def test_retrieval_metrics_appear_in_summary(self): + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + (src / "doc.md").write_text( + "# Doc\n\n" + "Apples grow on trees in the orchard during autumn.\n\n" + "Submarines navigate beneath the ocean using sonar.\n\n" + "Mountains rise above the clouds in the horizon.\n", + encoding="utf-8", + ) + + summary = run_parser_benchmark(src, tmp / "out", dataset_name="custom_folder") + + doc = summary["documents"][0] + self.assertTrue(doc["retrieval_evaluated"]) + self.assertGreater(doc["retrieval_query_count"], 0) + self.assertEqual(doc["retrieval_recall_at_1"], 1.0) + self.assertGreaterEqual(summary["mean_retrieval_recall_at_1"], 0.0) + self.assertEqual(summary["retrieval_evaluated_count"], 1) + self.assertTrue((tmp / "out" / "retrieval_runs.csv").exists()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_routing.py b/tests/test_routing.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea205420f70f6c8108cab4dbb70cce0cf5b2bd3 --- /dev/null +++ b/tests/test_routing.py @@ -0,0 +1,23 @@ +import tempfile +import unittest +from pathlib import Path + +from zsgdp.config import load_config +from zsgdp.profiling import profile_document +from zsgdp.routing import route_document + + +class RoutingTests(unittest.TestCase): + def test_markdown_routes_to_text_parser(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "sample.md" + path.write_text("# Hello\n\nWorld", encoding="utf-8") + profile = profile_document(path) + decisions = route_document(profile, load_config()) + + self.assertEqual(decisions[0].experts, ["text"]) + self.assertEqual(decisions[0].metadata["gpu_models_target"], "zeroshotGPU") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..30dc8333d6b0c33ca19ae2f119da2d5e5785cba0 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,35 @@ +import unittest + +from zsgdp.schema import Element, ParsedDocument + + +class SchemaTests(unittest.TestCase): + def test_element_content_prefers_markdown(self): + element = Element( + element_id="e1", + doc_id="d1", + page_num=1, + type="heading", + text="Heading", + markdown="## Heading", + ) + self.assertEqual(element.content(), "## Heading") + + def test_parsed_document_markdown_includes_page_boundary(self): + doc = ParsedDocument(doc_id="d1", source_path="sample.md", file_type="markdown") + doc.elements.append( + Element( + element_id="e1", + doc_id="d1", + page_num=1, + type="paragraph", + text="Hello world", + reading_order=1, + ) + ) + self.assertIn("", doc.to_markdown()) + self.assertIn("Hello world", doc.to_markdown()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_space_smoke.py b/tests/test_space_smoke.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b38c1e3ba412831ebb0e67ced6cf220b57bdb6 --- /dev/null +++ b/tests/test_space_smoke.py @@ -0,0 +1,170 @@ +"""Tests for the Space-side smoke runner.""" + +from __future__ import annotations + +import io +import json +import shutil +import tempfile +import unittest +from contextlib import redirect_stdout +from pathlib import Path +from unittest.mock import patch + +from scripts.run_space_smoke import ( + SMOKE_REGISTRY, + SmokeReport, + SmokeResult, + format_text_summary, + main, + run_smokes, + smoke_ablation, + smoke_embedding, + smoke_gpu_repair, + smoke_lexical, + smoke_marker, +) + + +class SmokeRegistryTests(unittest.TestCase): + def test_registry_lists_documented_smokes(self): + self.assertEqual( + set(SMOKE_REGISTRY), + {"lexical", "ablation", "embedding", "gpu_repair", "marker"}, + ) + + def test_each_smoke_is_callable(self): + for name, fn in SMOKE_REGISTRY.items(): + self.assertTrue(callable(fn), f"{name} is not callable") + + +class SmokeReportTests(unittest.TestCase): + def test_passed_property_is_true_when_all_pass_or_skip(self): + report = SmokeReport( + smokes=[ + SmokeResult(name="a", status="pass"), + SmokeResult(name="b", status="skip", skip_reason="dep missing"), + ] + ) + self.assertTrue(report.passed) + + def test_passed_is_false_on_any_fail_or_error(self): + report = SmokeReport( + smokes=[ + SmokeResult(name="a", status="pass"), + SmokeResult(name="b", status="fail"), + ] + ) + self.assertFalse(report.passed) + + report = SmokeReport(smokes=[SmokeResult(name="a", status="error")]) + self.assertFalse(report.passed) + + def test_to_dict_includes_summary_counts(self): + report = SmokeReport( + smokes=[ + SmokeResult(name="a", status="pass"), + SmokeResult(name="b", status="skip"), + SmokeResult(name="c", status="fail"), + SmokeResult(name="d", status="error"), + ] + ) + data = report.to_dict() + self.assertEqual(data["summary"]["total"], 4) + self.assertEqual(data["summary"]["passed"], 1) + self.assertEqual(data["summary"]["skipped"], 1) + self.assertEqual(data["summary"]["failed"], 1) + self.assertEqual(data["summary"]["errored"], 1) + + +class FormatTextSummaryTests(unittest.TestCase): + def test_marker_per_status(self): + report = SmokeReport( + smokes=[ + SmokeResult(name="a", status="pass", elapsed_seconds=0.1), + SmokeResult(name="b", status="fail", elapsed_seconds=0.2, detail={"why": "stale"}), + SmokeResult(name="c", status="skip", skip_reason="missing dep"), + SmokeResult(name="d", status="error", detail={"exception": "boom"}), + ] + ) + text = format_text_summary(report) + self.assertIn("[ok] a", text) + self.assertIn("[FAIL] b", text) + self.assertIn("reason=missing dep", text) + self.assertIn("[ERROR] d", text) + self.assertIn("smoke: FAIL", text) + + def test_strict_mode_marks_skipped_as_failure(self): + report = SmokeReport( + smokes=[ + SmokeResult(name="a", status="pass"), + SmokeResult(name="b", status="skip"), + ] + ) + self.assertIn("smoke: PASS", format_text_summary(report)) + self.assertIn("smoke: FAIL", format_text_summary(report, strict=True)) + + +class RunSmokesIntegrationTests(unittest.TestCase): + def test_lexical_passes_on_distinctive_corpus(self): + result = smoke_lexical() + self.assertEqual(result.status, "pass", msg=f"detail: {result.detail}") + self.assertGreaterEqual(result.detail["mean_quality_score"], 0.85) + self.assertGreaterEqual(result.detail["mean_retrieval_recall_at_1"], 0.7) + + def test_ablation_emits_three_arms(self): + result = smoke_ablation() + self.assertEqual(result.status, "pass", msg=f"detail: {result.detail}") + self.assertEqual(result.detail["arm_count"], 3) + self.assertSetEqual(set(result.detail["arms"]), {"text", "pymupdf", "merged"}) + self.assertTrue(result.detail["comparison_csv_emitted"]) + + def test_gpu_repair_smoke_reports_dry_run_plan(self): + result = smoke_gpu_repair() + # Either pass (transformers available, plan + iter detected) or skip + # (transformers missing on this Python). Both are acceptable. + self.assertIn(result.status, {"pass", "skip"}) + if result.status == "pass": + self.assertGreaterEqual(result.detail["dry_run_task_count"], 1) + self.assertGreaterEqual(result.detail["repair_iterations"], 1) + + def test_embedding_smoke_skips_when_sentence_transformers_missing(self): + # If sentence-transformers is genuinely installed in the test env, + # this becomes a no-op (the smoke runs and either passes or fails); + # we only assert the skip path is structured correctly. + with patch("scripts.run_space_smoke.importlib.util.find_spec", return_value=None): + result = smoke_embedding() + self.assertEqual(result.status, "skip") + self.assertIn("sentence-transformers", result.skip_reason) + self.assertIn("pip install", result.install_hint) + + def test_marker_smoke_skips_when_binary_missing(self): + with patch("scripts.run_space_smoke.shutil.which", return_value=None): + result = smoke_marker() + self.assertEqual(result.status, "skip") + self.assertIn("marker", result.skip_reason.lower()) + + +class RunSmokesDriverTests(unittest.TestCase): + def test_unknown_smoke_name_yields_error_result(self): + report = run_smokes(["nope"]) + self.assertEqual(len(report.smokes), 1) + self.assertEqual(report.smokes[0].status, "error") + self.assertIn("unknown smoke", report.smokes[0].detail["exception"]) + + def test_main_writes_json_report(self): + with tempfile.TemporaryDirectory() as tmp: + output = Path(tmp) / "report.json" + buf = io.StringIO() + with redirect_stdout(buf): + code = main(["--smoke", "ablation", "--output", str(output)]) + data = json.loads(output.read_text()) + + self.assertEqual(code, 0) + self.assertEqual(len(data["smokes"]), 1) + self.assertEqual(data["smokes"][0]["name"], "ablation") + self.assertEqual(data["summary"]["total"], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_structure_metrics.py b/tests/test_structure_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a35288516a7e720db2396092a4bfbfbe4e1ae2 --- /dev/null +++ b/tests/test_structure_metrics.py @@ -0,0 +1,257 @@ +"""Tests for table structure similarity, formula CER, and OmniDocBench adapters.""" + +from __future__ import annotations + +import json +import tempfile +import unittest +from pathlib import Path + +from zsgdp.benchmarks.datasets import DatasetDocument, register_dataset_loader +from zsgdp.benchmarks.datasets import _LOADERS as _DATASET_LOADERS +from zsgdp.benchmarks.ground_truth import ( + omnidocbench_formula_truths, + omnidocbench_table_truths, + parsed_formula_records, + parsed_table_records, +) +from zsgdp.benchmarks.parser_quality import run_parser_benchmark +from zsgdp.schema import Element, FigureObject, ParsedDocument, QualityReport, TableObject +from zsgdp.verify.formula_extraction import compute_formula_extraction +from zsgdp.verify.table_structure import compute_table_structure_score, html_to_rows, markdown_to_rows + + +class TestMarkdownAndHTMLRows(unittest.TestCase): + def test_markdown_strips_separator_row(self): + rows = markdown_to_rows("| A | B |\n| --- | --- |\n| 1 | 2 |\n") + self.assertEqual(rows, [["a", "b"], ["1", "2"]]) + + def test_html_handles_th_and_td(self): + html = "
Col
val
" + self.assertEqual(html_to_rows(html), [["col"], ["val"]]) + + +class TestComputeTableStructure(unittest.TestCase): + def test_perfect_match(self): + truth = {"markdown": "| A | B |\n| --- | --- |\n| 1 | 2 |", "page_num": 1} + prediction = {"markdown": "| A | B |\n| --- | --- |\n| 1 | 2 |", "page_num": 1} + result = compute_table_structure_score([prediction], [truth]) + self.assertEqual(result["matched_pair_count"], 1) + self.assertEqual(result["mean_table_score"], 1.0) + self.assertEqual(result["mean_cell_content_f1"], 1.0) + self.assertEqual(result["table_match_rate"], 1.0) + + def test_partial_overlap_scores_between_zero_and_one(self): + truth = {"markdown": "| A | B |\n| --- | --- |\n| 1 | 2 |", "page_num": 1} + prediction = {"markdown": "| A | B |\n| --- | --- |\n| 1 | 3 |", "page_num": 1} + result = compute_table_structure_score([prediction], [truth]) + self.assertEqual(result["matched_pair_count"], 1) + self.assertGreater(result["mean_table_score"], 0.0) + self.assertLess(result["mean_table_score"], 1.0) + + def test_extra_prediction_lowers_match_rate(self): + truth = {"markdown": "| A |\n| --- |\n| 1 |", "page_num": 1} + predictions = [ + {"markdown": "| A |\n| --- |\n| 1 |", "page_num": 1}, + {"markdown": "| Z |\n| --- |\n| 9 |", "page_num": 2}, + ] + result = compute_table_structure_score(predictions, [truth]) + self.assertEqual(result["matched_pair_count"], 1) + self.assertEqual(result["table_match_rate"], 0.5) + self.assertEqual(result["table_count_delta"], 1) + + def test_no_matching_page_yields_no_pair(self): + truth = {"markdown": "| A |\n| --- |\n| 1 |", "page_num": 1} + prediction = {"markdown": "| A |\n| --- |\n| 1 |", "page_num": 2} + result = compute_table_structure_score([prediction], [truth]) + self.assertEqual(result["matched_pair_count"], 0) + + def test_empty_inputs_are_vacuous(self): + result = compute_table_structure_score([], []) + self.assertEqual(result["mean_table_score"], 1.0) + self.assertEqual(result["table_match_rate"], 1.0) + + +class TestComputeFormulaExtraction(unittest.TestCase): + def test_exact_match_yields_zero_cer(self): + result = compute_formula_extraction( + [{"latex": "E = mc^2", "page_num": 1}], + [{"latex": "E = mc^2", "page_num": 1}], + ) + self.assertEqual(result["mean_cer"], 0.0) + self.assertEqual(result["mean_accuracy"], 1.0) + self.assertEqual(result["exact_match_rate"], 1.0) + + def test_one_char_off_yields_proportional_cer(self): + result = compute_formula_extraction( + [{"latex": "E = mc^3", "page_num": 1}], + [{"latex": "E = mc^2", "page_num": 1}], + ) + # Levenshtein distance 1 over reference length 8 + self.assertAlmostEqual(result["mean_cer"], 1 / 8, places=6) + self.assertEqual(result["exact_match_rate"], 0.0) + + def test_empty_inputs_are_vacuous(self): + result = compute_formula_extraction([], []) + self.assertEqual(result["mean_cer"], 0.0) + self.assertEqual(result["mean_accuracy"], 1.0) + + def test_one_side_empty_yields_full_error(self): + result = compute_formula_extraction([], [{"latex": "x", "page_num": 1}]) + self.assertEqual(result["mean_cer"], 1.0) + self.assertEqual(result["mean_accuracy"], 0.0) + + def test_dollar_delimiters_stripped(self): + result = compute_formula_extraction( + [{"latex": "$$E = mc^2$$", "page_num": 1}], + [{"latex": "E = mc^2", "page_num": 1}], + ) + self.assertEqual(result["exact_match_rate"], 1.0) + + def test_greedy_matching_picks_lowest_cer_pair(self): + predictions = [ + {"latex": "E = mc^2", "page_num": 1}, + {"latex": "F = ma", "page_num": 1}, + ] + truths = [ + {"latex": "F = ma", "page_num": 1}, + {"latex": "E = mc^2", "page_num": 1}, + ] + result = compute_formula_extraction(predictions, truths) + self.assertEqual(result["matched_pair_count"], 2) + self.assertEqual(result["exact_match_rate"], 1.0) + + +class TestOmniDocBenchAdapters(unittest.TestCase): + def test_table_truths_extract_markdown_and_page(self): + gt = { + "layout_dets": [ + {"category": "table", "markdown": "| A |\n| --- |\n| 1 |", "page_num": 1}, + {"category": "Title", "text": "ignore", "page_num": 1}, + {"category": "Table", "html": "
x
", "page_num": 2}, + ] + } + truths = omnidocbench_table_truths(gt) + self.assertEqual(len(truths), 2) + self.assertEqual(truths[0]["page_num"], 1) + self.assertEqual(truths[1]["page_num"], 2) + + def test_formula_truths_extract_latex(self): + gt = { + "layout_dets": [ + {"category": "formula", "latex": "E = mc^2", "page_num": 1}, + {"category": "Equation", "text": "F = ma", "page_num": 2}, + {"category": "Title", "text": "ignore", "page_num": 1}, + ] + } + truths = omnidocbench_formula_truths(gt) + self.assertEqual(len(truths), 2) + self.assertEqual(truths[0]["latex"], "E = mc^2") + self.assertEqual(truths[1]["latex"], "F = ma") + + def test_unknown_shape_returns_empty(self): + self.assertEqual(omnidocbench_table_truths({"weird": True}), []) + self.assertEqual(omnidocbench_formula_truths({}), []) + + +class TestParsedRecords(unittest.TestCase): + def test_parsed_table_records_dedupes_object_and_element(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.pdf", + file_type="pdf", + elements=[ + Element( + element_id="t1", + doc_id="d1", + page_num=1, + type="table", + markdown="| A |\n| --- |\n| 1 |", + ), + ], + tables=[ + TableObject( + table_id="t1", + page_nums=[1], + markdown="| A |\n| --- |\n| 1 |", + ), + ], + quality_report=QualityReport(), + ) + records = parsed_table_records(parsed) + # Both table objects keyed differently, so we get 2 records (table object + element). + # The dedupe key is per-source so they stay distinct, which is fine for matching. + self.assertGreaterEqual(len(records), 1) + self.assertTrue(any(record["table_id"] == "t1" for record in records)) + + def test_parsed_formula_records_extract_latex(self): + parsed = ParsedDocument( + doc_id="d1", + source_path="/tmp/d1.pdf", + file_type="pdf", + elements=[ + Element(element_id="f1", doc_id="d1", page_num=1, type="formula", text="E = mc^2"), + Element(element_id="p1", doc_id="d1", page_num=1, type="paragraph", text="not a formula"), + Element(element_id="f2", doc_id="d1", page_num=2, type="formula", text=""), + ], + quality_report=QualityReport(), + ) + records = parsed_formula_records(parsed) + self.assertEqual(len(records), 1) + self.assertEqual(records[0]["formula_id"], "f1") + self.assertEqual(records[0]["latex"], "E = mc^2") + + +class TestBenchmarkIntegration(unittest.TestCase): + def test_omnidocbench_smoke_run_emits_metrics(self): + # Use a markdown source with a one-shot loader that tags the document + # as `omnidocbench`. Lets us exercise the full benchmark wiring (table + + # formula adapters, CSVs) without needing PyMuPDF to parse bytes. + ground_truth = { + "layout_dets": [ + { + "category": "table", + "markdown": "| A | B |\n| --- | --- |\n| 1 | 2 |", + "page_num": 1, + }, + {"category": "formula", "latex": "E = mc^2", "page_num": 1}, + ] + } + + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + src = tmp / "in" + src.mkdir() + md_path = src / "doc.md" + md_path.write_text("# Doc\n\n| A | B |\n| --- | --- |\n| 1 | 2 |\n", encoding="utf-8") + + def fake_loader(root: Path): + yield DatasetDocument( + dataset_id="omnidocbench", + doc_id="doc", + path=md_path, + ground_truth=ground_truth, + metadata={}, + ) + + register_dataset_loader("omnidocbench", fake_loader) + try: + summary = run_parser_benchmark(src, tmp / "out", dataset_name="omnidocbench") + finally: + from zsgdp.benchmarks.ground_truth import omnidocbench_layout_truths + + # restore the real loader + from zsgdp.benchmarks.datasets import _load_omnidocbench + + _DATASET_LOADERS["omnidocbench"] = _load_omnidocbench + + self.assertEqual(summary["dataset_name"], "omnidocbench") + doc = summary["documents"][0] + self.assertTrue(doc["table_structure_evaluated"]) + self.assertTrue(doc["formula_evaluated"]) + self.assertTrue((tmp / "out" / "table_structure_runs.csv").exists()) + self.assertTrue((tmp / "out" / "formula_runs.csv").exists()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_text_parser.py b/tests/test_text_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..84c725a84b46a34215ebaff1ed63f4f331a7f4c6 --- /dev/null +++ b/tests/test_text_parser.py @@ -0,0 +1,25 @@ +import tempfile +import unittest +from pathlib import Path + +from zsgdp.config import load_config +from zsgdp.parsers.text_parser import TextParser +from zsgdp.profiling import profile_document + + +class TextParserTests(unittest.TestCase): + def test_text_parser_extracts_elements_and_tables(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "sample.md" + path.write_text("# Report\n\nParagraph.\n\n| A | B |\n| --- | --- |\n| 1 | 2 |\n", encoding="utf-8") + profile = profile_document(path) + + candidate = TextParser().parse(path, profile, load_config()) + + self.assertEqual(candidate.parser_name, "text") + self.assertGreaterEqual(len(candidate.elements), 2) + self.assertEqual(len(candidate.tables), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_verify.py b/tests/test_verify.py new file mode 100644 index 0000000000000000000000000000000000000000..257973ff116b08def33e0c5f87aee279dee8be6b --- /dev/null +++ b/tests/test_verify.py @@ -0,0 +1,83 @@ +import unittest + +from zsgdp.schema import DocumentProfile, Element, FigureObject, PageProfile, ParsedDocument +from zsgdp.verify import verify_parse +from zsgdp.verify.table_quality import markdown_table_is_valid + + +class VerifyTests(unittest.TestCase): + def test_verify_simple_document(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.txt", + file_type="text", + page_count=1, + extension=".txt", + pages=[PageProfile(page_num=1, digital_text_chars=11, digital_text_quality=1.0)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="sample.txt", file_type="text") + parsed.elements.append( + Element( + element_id="e1", + doc_id="d1", + page_num=1, + type="paragraph", + text="hello world", + reading_order=1, + confidence=0.9, + ) + ) + + parsed.quality_report = verify_parse(profile, parsed) + + self.assertEqual(parsed.quality_report.score, 1.0) + self.assertEqual(parsed.quality_report.metrics["element_count"], 1) + + def test_markdown_table_requires_data_row(self): + self.assertFalse(markdown_table_is_valid("| A | B |\n| --- | --- |")) + self.assertTrue(markdown_table_is_valid("| A | B |\n| --- | --- |\n| 1 | 2 |")) + + def test_verify_flags_missing_figure_context(self): + profile = DocumentProfile( + doc_id="d1", + source_path="sample.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=12, digital_text_quality=1.0)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="sample.pdf", file_type="pdf") + parsed.elements.append( + Element(element_id="e1", doc_id="d1", page_num=1, type="paragraph", text="hello world!", reading_order=1) + ) + parsed.figures.append(FigureObject(figure_id="f1", page_num=1)) + + report = verify_parse(profile, parsed) + + issue_types = [issue.issue_type for issue in report.issues] + self.assertIn("missing_figure_region", issue_types) + self.assertIn("missing_figure_caption", issue_types) + self.assertEqual(report.metrics["figure_description_coverage"], 0.0) + + def test_verify_flags_formula_heavy_page_without_formula_elements(self): + profile = DocumentProfile( + doc_id="d1", + source_path="math.pdf", + file_type="pdf", + page_count=1, + extension=".pdf", + pages=[PageProfile(page_num=1, digital_text_chars=24, digital_text_quality=1.0, formula_density=0.30)], + ) + parsed = ParsedDocument(doc_id="d1", source_path="math.pdf", file_type="pdf") + parsed.elements.append( + Element(element_id="e1", doc_id="d1", page_num=1, type="paragraph", text="Equation heavy page text", reading_order=1) + ) + + report = verify_parse(profile, parsed) + + self.assertIn("missing_formula_regions", [issue.issue_type for issue in report.issues]) + self.assertEqual(report.metrics["formula_page_coverage"], 0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/zero_shot_gpu_document_parser_project_spec.md b/zero_shot_gpu_document_parser_project_spec.md new file mode 100644 index 0000000000000000000000000000000000000000..c5150ec015381b179966420912069d42a333de08 --- /dev/null +++ b/zero_shot_gpu_document_parser_project_spec.md @@ -0,0 +1,2087 @@ +# Zero-Shot GPU Document Parser: Agentic Multimodal Ingestion Project Spec + +## 0. Project thesis + +Build a **self-hosted, zero-shot, GPU-accelerated document parser** that turns complex enterprise documents into auditable structured document representations and retrieval-ready chunks without fine-tuning. The parser should not rely on one extraction engine. Instead, it should behave like an **agentic parsing supervisor** that profiles each document/page, routes pages to the best open-source parser or VLM/OCR expert, verifies the result, repairs weak regions, and then runs an agentic chunking stage over the verified parse. It emits structured JSON, Markdown, table objects, figure objects, layout metadata, quality reports, provenance, and multi-strategy chunks. + +The guiding idea: + +> Do not build "one parser." Build a parsing control plane that can choose, compare, repair, and validate multiple parsers. + +--- + +## 1. Why this project exists + +Modern document automation fails when ingestion destroys document structure. The hard cases are not clean text PDFs. The hard cases are scanned PDFs, mixed digital/scanned PDFs, multi-column academic papers, financial statements, nested tables, multi-page tables, formulas, figures, charts, forms, handwritten notes, PowerPoints, spreadsheets, procedures split across pages, headers/footers/footnotes, and inconsistent reading order. + +The core production problem is: + +> Downstream systems cannot recover information that was lost, fragmented, or hallucinated during parsing. + +Chunking belongs inside the agentic pipeline because bad boundaries can destroy the structure the parser worked to recover. The chunker should use the parse tree, page provenance, bboxes, captions, tables, figures, and quality report instead of splitting raw text blindly. + +This project addresses that by treating parsing as a **multi-step decision process** instead of a single extraction call. + +--- + +## 2. Target output + +For each input document, produce a `ParsedDocument` object: + +```json +{ + "doc_id": "string", + "source_path": "string", + "file_type": "pdf | docx | pptx | xlsx | html | image | epub", + "pages": [], + "elements": [], + "tables": [], + "figures": [], + "chunks": [], + "quality_report": {}, + "provenance": {} +} +``` + +The parser should emit at least five artifacts: + +1. **Canonical Markdown** — human-readable; preserves headings, tables, equations, lists, captions, and page boundaries. +2. **Structured JSON** — element-level representation with type, page, bbox, reading order, source parser, confidence, and repair history. +3. **Specialized objects** — table objects, figure objects, formulas, captions, and layout regions. +4. **Agentic chunks** — baseline, structure-aware, parent/child, page-level, table, figure, and optional advanced chunks with provenance. +5. **Audit report** — which parser and chunking strategy was used where, quality scores, uncertain regions, repaired regions, fallback decisions. + +Embedding, vector indexing, and retrieval serving are downstream concerns, but chunk generation itself is part of this project. + +--- + +## 3. Non-goals + +Do not start by fine-tuning a custom model. + +Do not start by building a new OCR model. + +Do not start by replacing existing parsers. + +Do not start with "parse every format perfectly." + +The first version should be an orchestration, routing, verification, and repair system around strong open-source tools. + +--- + +## 4. Core open-source tools to evaluate + +### 4.1 Docling + +Use for general document conversion, layout-aware PDF parsing, reading order, table structure, formulas, multi-format ingestion, and clean JSON/Markdown export. + +Recommended role: + +```text +Default layout-aware parser for digital PDFs and office documents. +``` + +### 4.2 Marker + +Use for PDF/image/PPTX/DOCX/XLSX/HTML/EPUB to Markdown/JSON/HTML, GPU/CPU/MPS conversion, tables, forms, equations, code blocks, images, and optional LLM-assisted cleanup. + +Recommended role: + +```text +Default Markdown parser and strong alternative to Docling. +``` + +### 4.3 MinerU + +Use for scientific PDFs, textbooks, formulas, complex academic layouts, multilingual documents, Markdown/JSON conversion, and pipeline/VLM backends. + +Recommended role: + +```text +Scientific / formula-heavy / textbook parser. +``` + +### 4.4 olmOCR + +Use for scanned PDFs, image-based documents, pages where digital text extraction fails, clean linearized text in reading order, and tables/equations/handwriting cases. + +Recommended role: + +```text +VLM/OCR fallback for scanned or visually complex pages. +``` + +### 4.5 PaddleOCR / PP-StructureV3 + +Use for multilingual OCR, layout detection, table recognition, key information extraction, forms, invoices, and lightweight OCR-heavy deployments. + +Recommended role: + +```text +OCR + table/layout fallback, especially multilingual documents. +``` + +### 4.6 Unstructured + +Use for broad document ETL, element taxonomy, HTML, emails, PDFs, images, Office docs, and text/image partitioning. + +Recommended role: + +```text +General ETL fallback and connector-friendly parser. +``` + +### 4.7 Lightweight deterministic tools + +Use these as cheap first-pass tools and sanity checks: + +```text +PyMuPDF / fitz +pdfplumber +pypdf +Tesseract +MarkItDown +LibreOffice headless conversion for Office-to-PDF fallback +``` + +Recommended role: + +```text +Cheap profiling, text extraction baseline, fallback conversion, and parser sanity checks. +``` + +--- + +## 5. Open VLM/OCR model candidates + +These models are not all direct parser replacements. Use them as **page-level repair or verification experts**. + +### 5.1 Qwen2.5-VL + +Use for page image understanding, bounding-box-aware extraction, forms/tables/charts, visually complex pages, figure/chart captioning, and document QA verification. + +Recommended role: + +```text +General VLM reasoning and repair model. +``` + +### 5.2 InternVL family + +Use for multimodal reasoning, OCR/document understanding, chart/figure interpretation, and page-level verification. + +Recommended role: + +```text +Alternative VLM expert for hard-page repair and cross-checking. +``` + +### 5.3 DeepSeek-OCR + +Use for efficient OCR/VLM parsing, visual compression experiments, high-throughput page parsing, and structured Markdown recovery. + +Recommended role: + +```text +Efficient OCR-VLM backend candidate. +``` + +### 5.4 PaddleOCR-VL + +Use for multilingual document parsing, OCR + layout + table/formula/chart extraction, and lightweight VLM-style document parsing. + +Recommended role: + +```text +Multilingual OCR/VLM parser candidate. +``` + +--- + +## 6. Core architecture + +```text + ┌────────────────────┐ + │ Input Document │ + └─────────┬──────────┘ + │ + ▼ + ┌────────────────────┐ + │ Document Profiler │ + └─────────┬──────────┘ + │ + ▼ + ┌────────────────────┐ + │ Page Router │ + └─────────┬──────────┘ + │ + ┌───────────────────┼───────────────────┐ + ▼ ▼ ▼ +┌─────────────┐ ┌─────────────┐ ┌─────────────┐ +│ Docling │ │ Marker │ │ VLM/OCR │ +│ Expert │ │ Expert │ │ Expert │ +└──────┬──────┘ └──────┬──────┘ └──────┬──────┘ + └───────────────────┼───────────────────┘ + ▼ + ┌────────────────────┐ + │ Parse Merger │ + └─────────┬──────────┘ + ▼ + ┌────────────────────┐ + │ Quality Verifier │ + └─────────┬──────────┘ + ▼ + ┌────────────────────┐ + │ Repair Agent │ + └─────────┬──────────┘ + ▼ + ┌────────────────────┐ + │ Agentic Chunker │ + └─────────┬──────────┘ + ▼ + ┌────────────────────┐ + │ Artifact Export │ + └─────────┬──────────┘ + ▼ + ┌────────────────────┐ + │ ParsedDocument │ + └────────────────────┘ +``` + +--- + +## 7. Agentic parsing loop + +The system should operate page-by-page and region-by-region. + +### 7.1 High-level loop + +```python +for document in corpus: + profile = profile_document(document) + initial_plan = planner(profile) + parse_candidates = run_selected_parsers(document, initial_plan) + merged_parse = merge_candidates(parse_candidates) + quality = verify_parse(document, merged_parse) + + while quality.has_blocking_failures and budget_remaining: + repair_plan = choose_repairs(document, merged_parse, quality) + repaired_outputs = run_repairs(document, repair_plan) + merged_parse = merge_candidates([merged_parse, repaired_outputs]) + quality = verify_parse(document, merged_parse) + + chunk_plan = plan_chunks(profile, merged_parse, quality) + chunks = build_chunks(merged_parse, chunk_plan) + emit ParsedDocument(parsed=merged_parse, chunks=chunks, quality_report=quality) +``` + +### 7.2 Key principle + +The agent should not blindly trust VLM output. It should use VLMs for hard pages, repair, table linearization, figure captioning, verification questions, and conflict resolution. Use deterministic/layout parsers whenever they are good enough. + +--- + +## 8. Document profiler + +The profiler decides what the document contains and which parser experts are likely needed. + +### 8.1 Inputs + +```text +file extension +number of pages/slides/sheets +digital text availability +image coverage +OCR confidence if available +table density +formula density +figure/chart density +language/script +rotation/skew +multi-column likelihood +scanned-page likelihood +font/text layer quality +page complexity score +``` + +### 8.2 Features to compute cheaply + +For PDFs: + +```text +digital_text_chars_per_page +image_area_ratio +num_images_per_page +num_drawings_per_page +font_count +avg_chars_per_text_block +text_block_count +estimated_columns +table_candidate_count +has_embedded_text_layer +page_rotation +``` + +For images/scans: + +```text +resolution +blur score +contrast score +skew estimate +OCR confidence from cheap OCR +``` + +For Office docs: + +```text +native structure availability +tables count +images count +slides count +speaker notes +sheet count +merged cells +``` + +### 8.3 Page type labels + +```text +digital_text +scanned_text +multi_column +table_heavy +formula_heavy +figure_heavy +chart_heavy +form_like +handwritten +slide_like +spreadsheet_like +low_confidence +``` + +--- + +## 9. Parser router + +The router maps page/document profile to parser experts. + +### 9.1 Routing table + +```text +Clean digital PDF: + primary: Docling + backup: Marker + repair: none unless verifier fails + +Scientific PDF: + primary: MinerU or Marker + backup: Docling + repair: VLM formula/table pass + +Scanned PDF: + primary: olmOCR or PaddleOCR + backup: MinerU VLM backend + repair: Qwen2.5-VL page pass + +Table-heavy financial PDF: + primary: Docling + Marker + backup: PaddleOCR/PP-StructureV3 + repair: table structure agent + natural-language table renderer + +Slide deck: + primary: Docling or Marker + backup: native PPTX extraction + repair: VLM figure/chart captioning + +Spreadsheet: + primary: native XLSX parser + backup: Marker + repair: table summarizer + +Form/invoice: + primary: PaddleOCR/PP-StructureV3 + backup: Docling + repair: VLM key-value extraction +``` + +### 9.2 Router output + +```json +{ + "page_id": 12, + "experts": ["docling", "marker", "qwen_vl_repair"], + "reason": "table-heavy scanned page with low text-layer confidence", + "budget": { + "max_gpu_seconds": 6, + "max_retries": 2 + } +} +``` + +--- + +## 10. Parse candidate schema + +All parsers must be normalized into a common schema. + +```python +@dataclass +class Element: + element_id: str + doc_id: str + page_num: int + type: str # title, paragraph, table, figure, formula, list_item, header, footer, caption + text: str | None + markdown: str | None + html: str | None + bbox: tuple[float, float, float, float] | None + reading_order: int | None + confidence: float | None + source_parser: str + provenance: dict +``` + +```python +@dataclass +class TableObject: + table_id: str + page_nums: list[int] + bbox: list[tuple] | None + html: str | None + markdown: str | None + dataframe_json: dict | None + natural_language_rendering: str | None + caption: str | None + footnotes: list[str] + confidence: float + source_parser: str +``` + +```python +@dataclass +class FigureObject: + figure_id: str + page_num: int + bbox: tuple | None + image_path: str | None + caption: str | None + vlm_description: str | None + chart_data: dict | None + confidence: float + source_parser: str +``` + +--- + +## 11. Parse merger + +When multiple parsers are run, the merger decides which elements to keep. + +### 11.1 Alignment keys + +```text +page number +bbox overlap +text similarity +element type +reading order +table/figure captions +``` + +### 11.2 Merge rules + +Prefer deterministic/layout parser when: + +```text +digital text confidence high +bbox is available +reading order consistent +table HTML is structurally valid +``` + +Prefer VLM/OCR parser when: + +```text +digital text layer is missing +OCR confidence from deterministic tool is low +page is scanned +layout parser misses obvious regions +``` + +Prefer table-specific parser when: + +```text +table structure is valid +row/column count is consistent +header hierarchy preserved +``` + +### 11.3 Conflict cases + +Flag conflicts when: + +```text +two parsers disagree strongly on reading order +table row/column counts differ +text coverage differs by more than threshold +OCR text disagrees with digital text +figure/caption alignment differs +``` + +Conflict resolution should call a verification or repair agent, not silently choose. + +--- + +## 12. Quality verifier + +The verifier scores parse quality at page, region, and document level. + +### 12.1 Coverage checks + +```text +text coverage ratio +image coverage ratio +bbox coverage ratio +number of dropped pages +number of empty pages +ratio of parser elements to visual regions +``` + +### 12.2 Reading order checks + +```text +multi-column order consistency +heading followed by body text +caption near figure/table +page number/header/footer not inserted into body +lists remain contiguous +procedures remain ordered +``` + +### 12.3 Table checks + +```text +valid Markdown table +valid HTML table +consistent row lengths +header rows detected +merged cells represented +numeric columns preserved +footnotes attached +multi-page continuation detected +``` + +### 12.4 Formula checks + +```text +formula regions detected +LaTeX generated where possible +inline vs block formula distinction +symbol corruption score +``` + +### 12.5 Figure/chart checks + +```text +figure bbox detected +caption aligned +chart axis labels extracted +legend extracted +VLM description generated +``` + +### 12.6 Chunk-readiness checks + +```text +chunk length distribution +chunk boundary quality +parent/child link validity +orphan tables +orphan captions +chunks without page provenance +duplicate chunks +missing section headings +missing parser provenance +missing confidence scores +``` + +--- + +## 13. Repair agents + +Each repair agent receives a specific failure report and a cropped page/region if needed. + +### 13.1 Reading order repair + +Input: + +```text +page image +candidate elements with bboxes +current reading order +``` + +Output: + +```text +reordered element ids +confidence +explanation +``` + +Use when multi-column order is inconsistent, headings/captions are out of place, or procedure steps are fragmented. + +### 13.2 Table repair + +Input: + +```text +table crop +candidate markdown/html +surrounding caption/text +``` + +Output: + +```text +corrected HTML table +corrected Markdown table +natural-language rendering +confidence +``` + +Important: store multiple representations: + +```text +HTML for structure +Markdown for LLM readability +DataFrame/JSON for exact querying +natural-language prose for semantic retrieval +``` + +Natural-language table rendering example: + +```text +The table compares Q1 and Q2 revenue by region. North America increased from $10.2M to $12.1M, Europe declined from $8.4M to $7.9M, and APAC increased from $5.5M to $6.3M. +``` + +### 13.3 Figure/chart repair + +Input: + +```text +figure crop +caption candidate +nearby text +``` + +Output: + +```text +figure caption +chart summary +axis labels +legend labels +approximate extracted values if possible +``` + +Use VLMs here. + +### 13.4 OCR repair + +Input: + +```text +page image or region crop +bad OCR text +``` + +Output: + +```text +corrected text +reading order +confidence +``` + +Use when character noise is high, handwriting appears, scanned pages are present, or the text layer is corrupted. + +### 13.5 Header/footer cleanup + +Input: + +```text +elements across pages +bbox positions +text repetition stats +``` + +Output: + +```text +elements marked as header/footer/page_number/noise +``` + +This should be mostly deterministic. + +### 13.6 Section hierarchy repair + +Input: + +```text +elements with styles, font sizes, positions, text +``` + +Output: + +```text +section tree +heading levels +parent_id for each element +``` + +Use deterministic rules first. Use LLM only for ambiguous cases. + +--- + +## 14. Agentic controller + +### 14.1 Controller state + +```python +@dataclass +class ParseState: + doc_profile: dict + page_profiles: list[dict] + parse_candidates: dict + merged_parse: ParsedDocument + quality_report: QualityReport + repair_history: list[RepairAction] + budget: Budget +``` + +### 14.2 Controller actions + +```text +RUN_PARSER(parser, pages) +RUN_VLM_REPAIR(task, page_or_region) +RUN_TABLE_REPAIR(table_region) +RUN_READING_ORDER_REPAIR(page) +RUN_HEADER_FOOTER_CLEANUP(document) +MERGE_CANDIDATES +VERIFY +STOP_ACCEPT +STOP_FAIL_WITH_WARNINGS +``` + +### 14.3 Controller policy + +Use a deterministic policy first. + +```python +if page.is_scanned and text_coverage < 0.4: + run("olmocr", page) + +if page.table_density_high and table_quality_low: + run("table_repair", table_regions) + +if reading_order_score < threshold: + run("reading_order_repair", page) + +if figure_count > 0 and figure_descriptions_missing: + run("figure_caption_repair", figures) + +if quality_score >= accept_threshold: + stop_accept() +``` + +Later, replace this with a learned policy or bandit router. + +--- + +## 15. GPU execution plan + +### 15.1 Why GPU is needed + +GPU should be used for: + +```text +VLM/OCR inference +layout detection models +table recognition models +formula recognition +image/page rendering batches +semantic, late, vision-guided, and agentic chunking when enabled +``` + +CPU can handle: + +```text +PDF metadata +digital text extraction +Markdown cleanup +schema validation +header/footer detection +baseline and structure-aware chunk construction +``` + +### 15.2 Execution model + +Use a queue-based architecture: + +```text +CPU coordinator +GPU parser workers +GPU VLM repair workers +CPU merger/verifier workers +object store for page images/crops +``` + +### 15.3 Suggested libraries + +```text +Ray for distributed task scheduling +Celery/RQ for simpler queue +vLLM for supported VLM/OCR models +SGLang as alternative serving backend +Transformers for quick prototypes +PyTorch for model execution +CUDA image preprocessing where useful +``` + +### 15.4 Page batching + +Batch pages by: + +```text +same model +same resolution bucket +same task type +same prompt template +``` + +Avoid sending one page at a time to GPU if throughput matters. + +### 15.5 GPU budget + +Each document/page should carry a budget: + +```json +{ + "max_gpu_seconds": 60, + "max_vlm_calls": 20, + "max_repair_iterations": 3, + "allow_expensive_models": false +} +``` + +The controller should escalate only when cheaper parsers fail. + +--- + +## 16. Zero-shot prompting templates + +### 16.1 Page-to-Markdown repair + +```text +You are repairing a parsed document page. Convert the provided page image into faithful Markdown. +Preserve: +- headings +- paragraphs +- reading order +- tables +- equations +- figure captions +- lists +Do not summarize. Do not invent missing content. Mark uncertain text as [UNCERTAIN: ...]. +Return only Markdown. +``` + +### 16.2 Table repair + +```text +You are repairing a table extraction. +Given the table image and the candidate extraction, return: +1. valid HTML table +2. Markdown table +3. concise natural-language rendering preserving row-column relationships +4. notes on uncertainty + +Do not invent values. Preserve units, footnotes, headers, and merged-cell meaning. +``` + +### 16.3 Figure/chart description + +```text +Describe this figure for document understanding and downstream chunking. +Extract: +- figure type +- title/caption +- axes and units if chart +- legend labels +- main trend or relationship +- any visible numeric values +Do not overstate approximate values. Mark approximations explicitly. +``` + +### 16.4 Reading order repair + +```text +Given page element boxes and the page image, produce the correct reading order. +Return element IDs only in order. +Keep captions near their figures/tables. +Ignore repeated page headers/footers unless they contain unique content. +``` + +--- + +## 17. Agentic chunking strategy + +Chunking should happen **after** parsing, verification, and repair, but still inside the document-processing pipeline. The chunker should consume the verified parse tree instead of raw extracted text. + +The chunker is not a single algorithm. It is a planner that chooses a strategy ladder based on document type, structure quality, visual density, expected query style, budget, and embedding backend. + +### 17.1 Strategy ladder + +Generation 1 baselines: + +```text +fixed_token_baseline: + split every N tokens with optional overlap + used as the measurement baseline + +recursive_structure: + recursive paragraph/newline/sentence/word splitting + default target: 400-512 tokens + default overlap: 10-20% + +document_structure_aware: + split on Markdown/HTML/LaTeX headings, pages, sections, tables, figures, captions + preferred whenever canonical Markdown or layout structure is available +``` + +Generation 2 embedding-aware options: + +```text +semantic_chunking: + place boundaries where adjacent sentence embeddings diverge + use only when benchmarks on the target corpus justify the extra compute + +cluster_semantic_chunking: + group related sentences or paragraphs using embedding clusters + useful for long multi-topic narrative documents +``` + +Generation 3 context-preserving options: + +```text +late_chunking: + embed the full document with a long-context embedding model + mean-pool token embeddings into chunk vectors after the model has seen full context + enable only when the embedding backend supports it + +contextual_retrieval: + prepend a generated context sentence to each chunk before embedding + useful when chunks need to be self-describing + +parent_child_retrieval: + create small child chunks for precise search + link each child to larger parent chunks for generation context + default advanced strategy because it is cheap and high ROI + +sentence_window: + embed sentences and return a window around matching sentences + useful for procedural and financial-document QA +``` + +Generation 3.5 agentic and multimodal options: + +```text +agentic_proposition_chunking: + use an LLM to decompose high-value documents into atomic propositions + reserve for low-volume, high-stakes corpora + +vision_guided_chunking: + use rendered pages, bboxes, tables, figures, and VLM region understanding + use for visually rich documents such as financial filings, scientific papers, and slide decks + +question_aware_chunking: + optimize chunk boundaries for the expected query distribution + use after query logs or a synthetic QA set exists +``` + +### 17.2 Default production policy + +Start conservative: + +```text +1. Always emit fixed-size/token baseline chunks for benchmarking. +2. Use recursive structure-aware chunks at 400-512 tokens with 10-20% overlap as the default. +3. Split on headings, pages, tables, figures, captions, and list/procedure boundaries whenever the parser recovered them. +4. Emit parent/child chunks by default: + - child chunks: 100-500 tokens + - parent chunks: 500-2000 tokens +5. Emit table chunks from TableObject, not flattened prose only. +6. Emit figure chunks from FigureObject, captions, VLM descriptions, and crop references. +7. Escalate to semantic, late, contextual, agentic, or vision-guided chunking only when the profiler and benchmarks justify the cost. +``` + +### 17.3 Chunk planner inputs + +```text +file_type +page_count +section hierarchy confidence +table density +figure/chart density +formula density +multi-page table likelihood +procedure/list density +visual complexity score +quality_report failures +embedding backend capabilities +expected query type if known +latency/GPU budget +``` + +### 17.4 Chunk schema + +```json +{ + "chunk_id": "string", + "doc_id": "string", + "page_start": 1, + "page_end": 2, + "section_path": ["3. Benefits", "3.2 Eligibility"], + "element_ids": [], + "table_ids": [], + "figure_ids": [], + "content_type": "parent | prose | table | figure | formula | page | mixed", + "strategy": "recursive_structure | parent_child | page_level | table_object | vision_guided", + "parent_chunk_id": "string | null", + "child_chunk_ids": [], + "boundary_reason": "heading_boundary | page_boundary | token_budget | semantic_shift | visual_region", + "token_count": 420, + "source_parser": "docling+repair", + "quality_score": 0.91, + "requires_visual_context": false, + "context_prefix": "optional generated contextual sentence", + "metadata": {} +} +``` + +### 17.5 Table chunking + +For each table, store: + +1. exact structured table; +2. table Markdown; +3. table HTML where available; +4. natural-language rendering; +5. surrounding caption and paragraph context; +6. links to page numbers, bboxes, footnotes, and continuation pages. + +Do not rely only on a flat CSV/JSON representation. + +### 17.6 Figure chunking + +For each figure, store: + +1. caption; +2. VLM-generated description; +3. linked nearby prose; +4. image path/crop reference; +5. axes, legend, and chart values if extracted; +6. `requires_visual_context=true` when text alone is insufficient. + +### 17.7 Chunking evaluation rule + +There is no single best chunking strategy. Every corpus should benchmark at least: + +```text +fixed 200-word or fixed-token baseline +recursive 400-512 token baseline +structure-aware Markdown/page chunks +parent-child chunks +page-level chunks for paginated PDFs +semantic chunking if embeddings are available +late chunking if the embedding backend supports long-context pooling +vision-guided chunks for visually rich pages +agentic/proposition chunks for high-value low-volume documents +``` + +--- + +## 18. Benchmarking and evaluation + +### 18.1 Parser-level metrics + +```text +text coverage +character error rate where ground truth exists +reading order accuracy +layout element F1 +table detection F1 +table structure similarity +formula extraction accuracy +figure/caption alignment accuracy +parse latency +GPU seconds per page +cost per 1k pages +``` + +### 18.2 Chunking and retrieval-readiness metrics + +```text +retrieval recall@k +answer exactness +citation accuracy +table QA accuracy +figure QA accuracy +multi-page procedure QA +abstention correctness +parent-child resolution accuracy +chunk boundary precision +chunk provenance completeness +latency per chunking strategy +token overhead per chunking strategy +``` + +### 18.3 Quality-report metrics + +```text +percentage pages accepted first-pass +percentage pages requiring repair +repair success rate +parser disagreement rate +manual review rate +blocking failure rate +``` + +--- + +## 19. Benchmarks and datasets + +### 19.1 General document parsing + +Use: + +```text +OmniDocBench +DocLayNet +custom enterprise-style PDF set +``` + +### 19.2 Tables + +Use: + +```text +PubTabNet +FinTabNet / FinTabNet.c +PubTables-style datasets +custom financial statement tables +``` + +### 19.3 RAG-oriented evaluation + +Create your own QA set over: + +```text +HR policy documents +IT runbooks +financial reports +scientific papers +technical manuals +slide decks +spreadsheets +``` + +For each document, create questions requiring: + +```text +single paragraph lookup +table lookup +cross-page table lookup +figure interpretation +multi-section synthesis +procedure reconstruction +``` + +--- + +## 20. Testing strategy + +### 20.1 Unit tests + +```text +test_document_profile_detects_scanned_page +test_router_selects_ocr_for_scanned_page +test_router_selects_table_repair_for_bad_table +test_parser_outputs_normalized_schema +test_merger_deduplicates_overlapping_elements +test_header_footer_cleanup_removes_repeated_noise +test_quality_verifier_flags_empty_page +test_quality_verifier_flags_invalid_table +test_chunker_preserves_page_provenance +``` + +### 20.2 Integration tests + +```text +digital_pdf_clean +scanned_pdf +multi_column_paper +financial_report_with_tables +slide_deck_with_charts +spreadsheet_with_merged_cells +mixed_pdf_digital_and_scanned +``` + +### 20.3 Regression tests + +Keep a gold set of 50-100 pages. Every parser change should compare: + +```text +Markdown diff +JSON schema validity +table count +figure count +reading order score +RAG QA score +latency +GPU memory +``` + +--- + +## 21. MVP implementation plan + +### Phase 1: Parser harness + +Build wrappers for: + +```text +Docling +Marker +PyMuPDF fallback +``` + +Then add: + +```text +MinerU +olmOCR +PaddleOCR +Unstructured +``` + +All wrappers must output the common schema. + +Deliverables: + +```text +parser_wrappers/ +normalized_schema.py +parse_one.py +sample outputs +``` + +### Phase 2: Profiler + router + +Implement: + +```text +PDF/text-layer profiler +page image renderer +scanned-page detector +table-heavy detector +formula-heavy heuristic +multi-column heuristic +router decision table +``` + +Deliverables: + +```text +profile_document() +route_pages() +routing_report.json +``` + +### Phase 3: Merger + verifier + +Implement: + +```text +element alignment +duplicate removal +coverage checks +table validity checks +reading order heuristics +quality report +``` + +Deliverables: + +```text +merge_candidates() +verify_parse() +quality_report.json +``` + +### Phase 4: Repair loop + +Implement repair agents for: + +```text +table repair +reading order repair +figure/chart description +OCR repair +header/footer cleanup +``` + +Start with local VLM prompts. Keep every repair action logged. + +Deliverables: + +```text +repair_controller.py +repair_history.json +before_after_examples/ +``` + +### Phase 5: Agentic chunking + +Implement: + +```text +fixed-token baseline chunks +recursive structure-aware chunks +parent/child chunk hierarchy +page-level chunks for paginated documents +table multi-representation chunks +figure/visual-context chunks +chunk planner with strategy metadata +optional contextual-prefix hooks +optional semantic, late, vision-guided, and proposition chunking hooks +``` + +Deliverables: + +```text +chunks.jsonl +elements.jsonl +tables.jsonl +figures.jsonl +chunking_plan.json +``` + +### Phase 6: Benchmark suite + +Implement: + +```text +parse-quality benchmarks +chunking/retrieval-readiness benchmarks +latency/GPU benchmarks +parser ablations +chunking strategy ablations +agent-loop ablations +``` + +Deliverables: + +```text +benchmark_parser_quality.py +benchmark_structure_quality.py +benchmark_chunking_quality.py +benchmark_throughput.py +leaderboard.csv +``` + +--- + +## 22. Suggested repository structure + +```text +zero_shot_gpu_doc_parser/ + README.md + pyproject.toml + configs/ + default.yaml + parsers.yaml + routing.yaml + gpu.yaml + + zsgdp/ + __init__.py + + schema/ + document.py + elements.py + tables.py + figures.py + chunks.py + quality.py + + profiling/ + pdf_profile.py + page_render.py + scan_detect.py + layout_heuristics.py + + routing/ + router.py + policies.py + budgets.py + + parsers/ + base.py + docling_parser.py + marker_parser.py + mineru_parser.py + olmocr_parser.py + paddleocr_parser.py + unstructured_parser.py + pymupdf_parser.py + + normalize/ + normalize_docling.py + normalize_marker.py + normalize_mineru.py + normalize_olmocr.py + normalize_paddleocr.py + normalize_unstructured.py + + merge/ + align.py + dedupe.py + merge_candidates.py + conflict_detection.py + + verify/ + coverage.py + reading_order.py + table_quality.py + figure_quality.py + formula_quality.py + chunk_readiness.py + quality_report.py + + repair/ + controller.py + prompts.py + table_repair.py + reading_order_repair.py + figure_repair.py + ocr_repair.py + header_footer.py + + chunking/ + hierarchy.py + planner.py + splitters.py + chunker.py + table_chunks.py + figure_chunks.py + contextual.py + semantic.py + late.py + vision_guided.py + agentic.py + + gpu/ + worker.py + batching.py + model_server.py + vllm_client.py + transformers_client.py + + benchmarks/ + datasets.py + parser_quality.py + chunking_quality.py + structure_quality.py + throughput.py + ablations.py + + tests/ + test_schema.py + test_profiler.py + test_router.py + test_merge.py + test_verify.py + test_chunking.py + test_end_to_end.py + + examples/ + parse_pdf.py + parse_folder.py + run_benchmark.py +``` + +--- + +## 23. CLI design + +### Parse one document + +```bash +zsgdp parse \ + --input ./docs/report.pdf \ + --output ./out/report \ + --config configs/default.yaml +``` + +### Parse folder + +```bash +zsgdp parse-folder \ + --input ./docs \ + --output ./parsed \ + --workers 8 \ + --gpu-workers 2 +``` + +### Benchmark + +```bash +zsgdp benchmark \ + --dataset omnidocbench \ + --parsers docling marker mineru olmocr \ + --output ./benchmarks/results +``` + +### Chunk artifact export + +```bash +zsgdp export-chunks \ + --parsed ./parsed/report \ + --format jsonl \ + --output ./chunks/report_chunks.jsonl +``` + +--- + +## 24. Configuration example + +```yaml +parsers: + docling: + enabled: true + device: cuda + marker: + enabled: true + device: cuda + mineru: + enabled: true + backend: pipeline + device: cuda + olmocr: + enabled: true + device: cuda + paddleocr: + enabled: true + device: cuda + unstructured: + enabled: true + pymupdf: + enabled: true + +routing: + run_multiple_on_hard_pages: true + max_primary_parsers_per_page: 2 + hard_page_threshold: 0.65 + scanned_text_threshold: 0.40 + table_density_threshold: 0.25 + +repair: + enabled: true + max_iterations: 3 + table_repair: true + reading_order_repair: true + figure_repair: true + ocr_repair: true + +gpu: + backend: transformers + batch_pages: true + max_batch_size: 4 + max_gpu_seconds_per_doc: 120 + max_vlm_calls_per_doc: 30 + +quality: + accept_threshold: 0.88 + blocking_failures: + - empty_page + - invalid_table + - missing_text_coverage + - reading_order_failure + +chunking: + enabled: true + planner: agentic + baseline_strategy: recursive_structure + target_tokens: 512 + min_tokens: 120 + overlap_ratio: 0.15 + parent_child: true + parent_target_tokens: 1600 + page_level_for_paginated_docs: true + table_chunks: true + figure_chunks: true + contextual_prefix: false + semantic_chunking: false + late_chunking: false + vision_guided: false + agentic_proposition_chunking: false + strategy_ladder: + - fixed_token_baseline + - recursive_structure + - metadata_enriched + - parent_child + - late_chunking + - semantic_chunking + - vision_guided + - agentic_proposition +``` + +--- + +## 25. Core algorithms + +### 25.1 Page routing algorithm + +```python +def route_page(page_profile): + experts = [] + + if page_profile.scanned_score > 0.7: + experts += ["olmocr", "paddleocr"] + + if page_profile.table_density > 0.25: + experts += ["docling", "marker", "paddleocr"] + + if page_profile.formula_density > 0.15: + experts += ["mineru", "marker"] + + if page_profile.figure_density > 0.2: + experts += ["docling", "marker"] + experts += ["vlm_figure_repair"] + + if page_profile.digital_text_quality > 0.8: + experts += ["docling"] + + if not experts: + experts = ["docling", "marker"] + + return dedupe_and_budget(experts) +``` + +### 25.2 Repair loop algorithm + +```python +def repair_loop(document, merged_parse, quality_report, budget): + while quality_report.has_blocking_failures() and budget.remaining(): + failures = quality_report.get_failures() + repair_actions = [] + + for failure in failures: + if failure.type == "invalid_table": + repair_actions.append(TableRepair(failure.region_id)) + elif failure.type == "reading_order_failure": + repair_actions.append(ReadingOrderRepair(failure.page_num)) + elif failure.type == "missing_figure_caption": + repair_actions.append(FigureRepair(failure.figure_id)) + elif failure.type == "low_ocr_confidence": + repair_actions.append(OCRRepair(failure.page_num)) + + repair_outputs = run_repair_actions(repair_actions, budget) + merged_parse = merge_candidates([merged_parse, repair_outputs]) + quality_report = verify_parse(document, merged_parse) + + return merged_parse, quality_report +``` + +### 25.3 Chunk planning algorithm + +```python +def plan_chunks(profile, parsed_document, quality_report, config): + plan = ["fixed_token_baseline", "recursive_structure"] + + if parsed_document.has_headings or parsed_document.file_type in ["markdown", "html"]: + plan.append("document_structure_aware") + + if config.chunking.parent_child: + plan.append("parent_child") + + if profile.file_type == "pdf" and profile.page_count > 1: + plan.append("page_level") + + if profile.table_density_high: + plan.append("table_object_chunks") + + if profile.figure_density_high: + plan.append("figure_object_chunks") + + if config.embedding_backend.supports_late_chunking and profile.long_cross_referenced_doc: + plan.append("late_chunking") + + if profile.long_multi_topic_narrative and budget.allows_embeddings: + plan.append("semantic_chunking") + + if profile.visual_complexity_high and budget.allows_vlm: + plan.append("vision_guided_chunking") + + if profile.high_value_low_volume and budget.allows_llm: + plan.append("agentic_proposition_chunking") + + return plan +``` + +### 25.4 Chunk construction algorithm + +```python +def build_chunks(parsed_document, chunk_plan): + parent_chunks = split_by_section_page_and_parent_budget(parsed_document) + child_chunks = [] + + for parent in parent_chunks: + children = recursive_split( + parent.text, + target_tokens=512, + overlap_ratio=0.15, + preserve_boundaries=["heading", "list", "table", "figure", "caption"] + ) + link_parent_child(parent, children) + child_chunks.extend(children) + + table_chunks = build_table_chunks(parsed_document.tables) + figure_chunks = build_figure_chunks(parsed_document.figures) + page_chunks = build_page_chunks(parsed_document.pages) + + return parent_chunks + child_chunks + table_chunks + figure_chunks + page_chunks +``` + +### 25.5 Table natural-language rendering + +```python +def render_table_as_text(table): + prompt = f''' + Convert this table into concise natural-language statements. + Preserve row-column relationships, units, headers, and footnotes. + Do not summarize away important numeric values. + + HTML: + {table.html} + ''' + return vlm_or_llm(prompt) +``` + +Use this only after validating the table structure. + +--- + +## 26. Output format examples + +### 26.1 Element JSONL + +```json +{"element_id":"e1","type":"title","page_num":1,"text":"Benefits Eligibility","bbox":[72,91,430,120],"source_parser":"docling","confidence":0.94} +``` + +### 26.2 Table JSONL + +```json +{"table_id":"t4","page_nums":[12,13],"caption":"Quarterly Revenue by Region","html":"...
","markdown":"| Region | Q1 | Q2 |","natural_language_rendering":"The table compares Q1 and Q2 revenue by region...","confidence":0.88} +``` + +### 26.3 Chunk JSONL + +```json +{"chunk_id":"c18","doc_id":"policy_001","page_start":4,"page_end":5,"section_path":["Leave Policy","Eligibility"],"content_type":"prose","strategy":"recursive_structure","parent_chunk_id":"c17","boundary_reason":"recursive_separator","token_count":412,"text":"Employees are eligible for...","element_ids":["e31","e32"],"table_ids":[],"figure_ids":[],"quality_score":0.92} +``` + +--- + +## 27. Model and parser selection policy + +### First implementation stack + +Use: + +```text +Docling +Marker +PyMuPDF +one VLM/OCR repair model +``` + +Then add: + +```text +MinerU +olmOCR +PaddleOCR +Unstructured +``` + +Why not integrate everything immediately? + +Because the hard part is not calling parsers. The hard part is: + +```text +normalization +merging +verification +repair routing +benchmarking +``` + +The MVP should prove the control plane works before adding every parser. + +--- + +## 28. Evaluation ablations + +Run these ablations: + +```text +Docling only +Marker only +MinerU only +olmOCR only on scans +router without repair +router with repair +router with table natural-language rendering +router with figure descriptions +full agentic loop +fixed-token chunks +recursive-structure chunks +page-level chunks +parent-child chunks +semantic chunks where enabled +late chunks where enabled +vision-guided chunks where enabled +agentic proposition chunks where enabled +``` + +Measure: + +```text +parse quality +chunk boundary quality +retrieval readiness +answer quality +latency +GPU seconds +manual review rate +token overhead +``` + +The key question: + +> Does the agentic parse-and-chunk loop outperform the best single parser plus baseline chunker enough to justify its overhead? + +--- + +## 29. Success criteria + +### MVP success + +The system is worth continuing if: + +```text +full agentic loop improves table QA by >= 20% over best single parser +agentic chunking improves citation accuracy by >= 10% over recursive baseline +parent-child chunks improve multi-section answers without excessive latency +manual inspection shows fewer reading-order failures +latency remains within acceptable budget +``` + +### Production-style success + +For a scoped workflow, e.g. HR policy or financial reports: + +```text +retrieval recall@5 >= 90% +citation accuracy >= 90% +table QA exactness >= 85% +manual review rate <= 10% +parser blocking failure rate <= 5% +``` + +--- + +## 30. Risks and mitigations + +### Risk 1: VLM hallucination + +Mitigation: + +```text +never let VLM overwrite deterministic text without provenance +mark uncertain output +require consistency checks +use cropped regions instead of full-page vague prompts +store parser disagreement +``` + +### Risk 2: table repair invents values + +Mitigation: + +```text +extract table image crop +require cell-by-cell structure +compare numeric strings against OCR/digital text +flag unverified cells +store original table crop +``` + +### Risk 3: agent loop is too slow + +Mitigation: + +```text +use cheap profiler first +run expensive VLM only on hard pages +batch page images +cache parser outputs +respect GPU budgets +``` + +### Risk 4: too many parser dependencies + +Mitigation: + +```text +plugin architecture +optional parser extras +Docker profiles +minimal MVP stack +``` + +### Risk 5: benchmarking becomes subjective + +Mitigation: + +```text +create gold-page set +use parser-level metrics +use RAG QA metrics +store before/after examples +track exact failure types +``` + +--- + +## 31. Immediate coding milestones + +### Week 1 + +```text +Create schema objects +Implement PyMuPDF profiler +Implement Docling wrapper +Implement Marker wrapper +Normalize outputs to common schema +Parse 10 sample PDFs +``` + +### Week 2 + +```text +Implement router +Implement quality verifier +Implement table validity checks +Implement reading-order heuristics +Implement simple merge logic +``` + +### Week 3 + +```text +Add VLM repair backend +Implement table repair prompt +Implement figure caption prompt +Implement repair loop +Emit quality reports +``` + +### Week 4 + +```text +Build agentic chunk planner +Add fixed-token and recursive-structure baselines +Add parent-child chunk hierarchy +Add page-level, table, and figure chunk representations +Run chunking ablations on small QA set +Add benchmark scripts +``` + +### Week 5+ + +```text +Add MinerU +Add olmOCR +Add PaddleOCR +Add micro-batching / Ray GPU workers +Add OmniDocBench/PubTabNet-style evaluation +``` + +--- + +## 32. Final system claim + +The final project should claim: + +> We built a zero-shot, self-hosted, GPU-accelerated document parsing and chunking control plane. It routes pages to the best open-source parser, verifies structure, repairs weak regions with local VLM/OCR calls, and emits auditable, multi-representation document objects plus strategy-aware chunks. + +Do not claim: + +> We invented a better OCR model. + +Claim: + +> We made open-source document parsing and chunking operational, adaptive, and measurable for complex downstream retrieval pipelines. + +--- + +## 33. Reference shortlist + +Primary tools: + +- Docling: open-source document conversion, layout/table/formula support, JSON/Markdown export. +- Marker: local GPU/CPU/MPS document-to-Markdown/JSON/HTML parser. +- MinerU: open-source parser for PDFs, images, DOCX, PPTX, XLSX, strong for scientific/symbolic documents. +- olmOCR: Allen AI open-source OCR/VLM toolkit for linearizing PDFs and image documents. +- PaddleOCR / PP-StructureV3: multilingual OCR and hierarchical document parsing. +- Unstructured: open-source document ETL and element extraction library. + +Evaluation: + +- OmniDocBench: diverse PDF document parsing benchmark. +- DocLayNet: layout segmentation dataset with 80k+ annotated pages. +- PubTabNet: large table recognition dataset with 568k+ table images. +- FinTabNet / FinTabNet.c: financial table structure recognition datasets. + +Model backends: + +- Qwen2.5-VL +- InternVL +- DeepSeek-OCR +- PaddleOCR-VL +- olmOCR models + +--- + +## 34. Deployment target + +Targeted deployment: Hugging Face Spaces. + +GPU/model target: Hugging Face Spaces GPU/Models - zeroshotGPU. diff --git a/zsgdp/__init__.py b/zsgdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1bcb12ea74fc62038f5e50f3f0e66357408e48eb --- /dev/null +++ b/zsgdp/__init__.py @@ -0,0 +1,8 @@ +"""Zero-shot GPU document parser MVP.""" + +from zsgdp.pipeline import parse_document +from zsgdp.profiling.document_profile import profile_document + +__all__ = ["parse_document", "profile_document"] + +__version__ = "0.1.0" diff --git a/zsgdp/artifacts.py b/zsgdp/artifacts.py new file mode 100644 index 0000000000000000000000000000000000000000..1c983a3c0251b1de03f57d6d8e8eb02169ae22ec --- /dev/null +++ b/zsgdp/artifacts.py @@ -0,0 +1,115 @@ +"""Artifact manifest generation and validation.""" + +from __future__ import annotations + +from datetime import datetime, timezone +import hashlib +import json +from pathlib import Path +from typing import Any + +from zsgdp.schema import SCHEMA_VERSION, ParsedDocument +from zsgdp.utils import write_json + +MANIFEST_FILENAME = "artifact_manifest.json" + +# The manifest format itself has a separate version from the parsed-document +# schema it describes. Bump this when the manifest's own keys change. +MANIFEST_SCHEMA_VERSION = 1 + + +def write_artifact_manifest(output_dir: str | Path, parsed: ParsedDocument) -> dict[str, Any]: + manifest = build_artifact_manifest(output_dir, parsed) + write_json(Path(output_dir) / MANIFEST_FILENAME, manifest) + return manifest + + +def build_artifact_manifest(output_dir: str | Path, parsed: ParsedDocument) -> dict[str, Any]: + out = Path(output_dir) + files = [_file_record(path, out) for path in _artifact_files(out)] + return { + "schema_version": MANIFEST_SCHEMA_VERSION, + "parsed_document_schema_version": getattr(parsed, "schema_version", SCHEMA_VERSION), + "generated_at_utc": datetime.now(timezone.utc).isoformat(), + "doc_id": parsed.doc_id, + "source_path": parsed.source_path, + "file_type": parsed.file_type, + "quality_score": parsed.quality_report.score, + "counts": { + "pages": len(parsed.pages), + "elements": len(parsed.elements), + "tables": len(parsed.tables), + "figures": len(parsed.figures), + "chunks": len(parsed.chunks), + "gpu_tasks": len(parsed.provenance.get("gpu_tasks", [])), + }, + "artifact_count": len(files), + "files": files, + } + + +def validate_artifact_manifest(output_dir: str | Path) -> dict[str, Any]: + out = Path(output_dir) + manifest_path = out / MANIFEST_FILENAME + if not manifest_path.exists(): + return { + "valid": False, + "manifest_path": str(manifest_path), + "errors": [f"Missing {MANIFEST_FILENAME}"], + "checked_count": 0, + } + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + errors: list[str] = [] + checked = 0 + for record in manifest.get("files", []): + relative_path = str(record.get("path", "")) + path = out / relative_path + if not path.exists(): + errors.append(f"Missing artifact: {relative_path}") + continue + checked += 1 + expected_size = int(record.get("size_bytes", -1)) + actual_size = path.stat().st_size + if actual_size != expected_size: + errors.append(f"Size mismatch: {relative_path}") + expected_sha = str(record.get("sha256", "")) + actual_sha = sha256_file(path) + if actual_sha != expected_sha: + errors.append(f"SHA-256 mismatch: {relative_path}") + + return { + "valid": not errors, + "manifest_path": str(manifest_path), + "doc_id": manifest.get("doc_id"), + "artifact_count": int(manifest.get("artifact_count", 0)), + "checked_count": checked, + "manifest_schema_version": manifest.get("schema_version"), + "parsed_document_schema_version": manifest.get("parsed_document_schema_version"), + "errors": errors, + } + + +def sha256_file(path: str | Path) -> str: + digest = hashlib.sha256() + with Path(path).open("rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _artifact_files(output_dir: Path) -> list[Path]: + return sorted( + path + for path in output_dir.rglob("*") + if path.is_file() and path.name != MANIFEST_FILENAME + ) + + +def _file_record(path: Path, output_dir: Path) -> dict[str, Any]: + stat = path.stat() + return { + "path": path.relative_to(output_dir).as_posix(), + "size_bytes": stat.st_size, + "sha256": sha256_file(path), + } diff --git a/zsgdp/benchmarks/__init__.py b/zsgdp/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68944ecd90f22829d002a292fce654e232e5db49 --- /dev/null +++ b/zsgdp/benchmarks/__init__.py @@ -0,0 +1,3 @@ +"""Benchmark scaffolding.""" + +__all__: list[str] = [] diff --git a/zsgdp/benchmarks/ablation_runner.py b/zsgdp/benchmarks/ablation_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..c5dac81d81499c8a192fca0549a318e4f5f00c2d --- /dev/null +++ b/zsgdp/benchmarks/ablation_runner.py @@ -0,0 +1,133 @@ +"""Per-parser ablation runner. + +Definitions (pinned): + +- An "ablation arm" is one full benchmark run with a fixed parser selection. + We run one arm per parser plus a "merged" arm that uses all enabled parsers. +- Per-arm summary is the same shape as run_parser_benchmark's summary; the + comparative CSV picks a flat subset of metrics so arms can be eyeballed + side-by-side. +- The runner does NOT re-define metrics. It calls run_parser_benchmark with + a different `selected_parsers` argument per arm and collates the results. + The expensive work is the parses themselves; metrics are computed by the + same code paths so arm-vs-arm comparisons stay consistent. +""" + +from __future__ import annotations + +import csv +from pathlib import Path +from typing import Any, Sequence + +from zsgdp.benchmarks.parser_quality import run_parser_benchmark +from zsgdp.utils import write_json + +ABLATION_METRIC_KEYS = ( + "document_count", + "mean_quality_score", + "mean_parser_disagreement_rate", + "mean_repair_resolution_rate", + "mean_repair_regression_rate", + "mean_layout_f1", + "mean_layout_class_agnostic_f1", + "layout_evaluated_count", + "mean_table_structure_score", + "mean_table_match_rate", + "table_structure_evaluated_count", + "mean_formula_cer", + "mean_formula_accuracy", + "formula_evaluated_count", + "mean_retrieval_recall_at_1", + "mean_retrieval_recall_at_5", + "mean_retrieval_mrr", + "retrieval_evaluated_count", +) + + +def run_parser_ablations( + input_dir: str | Path, + output_dir: str | Path, + *, + parsers: Sequence[str], + config_path: str | Path | None = None, + dataset_name: str = "custom_folder", + include_merged: bool = True, +) -> dict[str, Any]: + if not parsers: + raise ValueError("run_parser_ablations requires at least one parser.") + + output_root = Path(output_dir) + output_root.mkdir(parents=True, exist_ok=True) + arms: list[dict[str, Any]] = [] + + for parser_name in parsers: + arm_root = output_root / f"arm_{parser_name}" + summary = run_parser_benchmark( + input_dir, + arm_root, + config_path=config_path, + selected_parsers=[parser_name], + dataset_name=dataset_name, + ) + arms.append({"arm": parser_name, "selected_parsers": [parser_name], "output": str(arm_root), "summary": summary}) + + if include_merged and len(parsers) > 1: + merged_root = output_root / "arm_merged" + merged_summary = run_parser_benchmark( + input_dir, + merged_root, + config_path=config_path, + selected_parsers=list(parsers), + dataset_name=dataset_name, + ) + arms.append( + { + "arm": "merged", + "selected_parsers": list(parsers), + "output": str(merged_root), + "summary": merged_summary, + } + ) + + comparison = _build_comparison(arms) + write_json(output_root / "ablation_summary.json", comparison) + _write_comparison_csv(output_root / "ablation_comparison.csv", arms) + return comparison + + +def _build_comparison(arms: list[dict[str, Any]]) -> dict[str, Any]: + rows = [] + for arm in arms: + summary = arm["summary"] + row = { + "arm": arm["arm"], + "selected_parsers": list(arm["selected_parsers"]), + "output": arm["output"], + } + for key in ABLATION_METRIC_KEYS: + value = summary.get(key) + if value is None: + continue + row[key] = value + rows.append(row) + return { + "arm_count": len(rows), + "metric_keys": list(ABLATION_METRIC_KEYS), + "rows": rows, + } + + +def _write_comparison_csv(path: Path, arms: list[dict[str, Any]]) -> None: + fieldnames = ["arm", "selected_parsers", *ABLATION_METRIC_KEYS] + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for arm in arms: + summary = arm["summary"] + row: dict[str, Any] = { + "arm": arm["arm"], + "selected_parsers": "|".join(arm["selected_parsers"]), + } + for key in ABLATION_METRIC_KEYS: + row[key] = summary.get(key, "") + writer.writerow(row) diff --git a/zsgdp/benchmarks/ablations.py b/zsgdp/benchmarks/ablations.py new file mode 100644 index 0000000000000000000000000000000000000000..7459e83aeceb93b44aff0dc04a95bc784c867cbf --- /dev/null +++ b/zsgdp/benchmarks/ablations.py @@ -0,0 +1,94 @@ +"""Ablation names from the project spec.""" + +from __future__ import annotations + +ABLATIONS = [ + "docling_only", + "marker_only", + "mineru_only", + "olmocr_only_on_scans", + "router_without_repair", + "router_with_repair", + "router_with_table_natural_language_rendering", + "router_with_figure_descriptions", + "full_agentic_loop", +] + +CHUNKING_ABLATIONS = [ + "fixed_token_chunks", + "recursive_structure_chunks", + "page_level_chunks", + "parent_child_chunks", + "semantic_chunks", + "late_chunks", + "vision_guided_chunks", + "agentic_proposition_chunks", +] + + +def ablation_plan(config: dict, observed_chunk_strategies: set[str]) -> list[dict]: + parsers = config.get("parsers", {}) + chunking = config.get("chunking", {}) + return [ + *_parser_ablation_records(parsers), + { + "ablation": "router_with_repair", + "category": "pipeline", + "status": "available" if config.get("repair", {}).get("enabled", True) else "disabled", + "notes": "Run the default pipeline with repair enabled.", + }, + { + "ablation": "router_without_repair", + "category": "pipeline", + "status": "available", + "notes": "Run with config override repair.enabled=false.", + }, + { + "ablation": "full_agentic_loop", + "category": "pipeline", + "status": "available", + "notes": "Profiler, router, parser merge, repair, verification, chunking, and GPU task planning.", + }, + *_chunking_ablation_records(chunking, observed_chunk_strategies), + ] + + +def _parser_ablation_records(parsers: dict) -> list[dict]: + records = [] + for parser_name in ["docling", "marker", "mineru", "olmocr", "paddleocr", "unstructured"]: + enabled = bool(parsers.get(parser_name, {}).get("enabled", False)) + records.append( + { + "ablation": f"{parser_name}_only", + "category": "parser", + "status": "configured" if enabled else "planned", + "notes": f"Run benchmark with --parser {parser_name}.", + } + ) + return records + + +def _chunking_ablation_records(chunking: dict, observed_chunk_strategies: set[str]) -> list[dict]: + strategy_to_config = { + "fixed_token_chunks": ("fixed_token_baseline", True), + "recursive_structure_chunks": ("recursive_structure", True), + "page_level_chunks": ("page_level", chunking.get("page_level_for_paginated_docs", True)), + "parent_child_chunks": ("parent_child", chunking.get("parent_child", True)), + "semantic_chunks": ("semantic", chunking.get("semantic_chunking", False)), + "late_chunks": ("late", chunking.get("late_chunking", False)), + "vision_guided_chunks": ("vision_guided", chunking.get("vision_guided", False)), + "agentic_proposition_chunks": ("agentic_proposition", chunking.get("agentic_proposition_chunking", False)), + } + records = [] + for ablation, (strategy, enabled) in strategy_to_config.items(): + observed = strategy in observed_chunk_strategies + status = "observed" if observed else "configured" if enabled else "planned" + records.append( + { + "ablation": ablation, + "category": "chunking", + "status": status, + "notes": f"Strategy key: {strategy}.", + } + ) + return records diff --git a/zsgdp/benchmarks/chunking_quality.py b/zsgdp/benchmarks/chunking_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..c6e1fb33da1ee7ac2e6a0c904a0c9642ab0931de --- /dev/null +++ b/zsgdp/benchmarks/chunking_quality.py @@ -0,0 +1,108 @@ +"""Chunking-quality benchmark helpers.""" + +from __future__ import annotations + +from statistics import mean + +from zsgdp.schema import ParsedDocument + + +def score_chunking_quality(results: list[dict]) -> dict[str, float]: + if not results: + return { + "boundary_precision": 0.0, + "parent_child_resolution": 0.0, + "provenance_completeness": 0.0, + "retrieval_readiness": 0.0, + } + return { + "boundary_precision": sum(float(item.get("boundary_precision", 0.0)) for item in results) / len(results), + "parent_child_resolution": sum(float(item.get("parent_child_resolution", 0.0)) for item in results) / len(results), + "provenance_completeness": sum(float(item.get("provenance_completeness", 0.0)) for item in results) / len(results), + "retrieval_readiness": sum(float(item.get("retrieval_readiness", 0.0)) for item in results) / len(results), + } + + +def chunking_quality_record(parsed: ParsedDocument, source_path: str) -> dict: + """Build deterministic chunk/retrieval-readiness proxy metrics.""" + + chunks = parsed.chunks + if not chunks: + return { + "source_path": source_path, + "doc_id": parsed.doc_id, + "chunk_count": 0, + "boundary_precision": 0.0, + "parent_child_resolution": 0.0, + "provenance_completeness": 0.0, + "retrieval_readiness": 0.0, + "table_chunk_coverage": 0.0, + "figure_chunk_coverage": 0.0, + "avg_tokens": 0.0, + "max_tokens": 0, + } + + chunk_ids = {chunk.chunk_id for chunk in chunks} + parent_ids = {chunk.chunk_id for chunk in chunks if chunk.content_type == "parent"} + child_chunks = [chunk for chunk in chunks if chunk.parent_chunk_id] + parent_chunks = [chunk for chunk in chunks if chunk.content_type == "parent"] + valid_child_links = sum(1 for chunk in child_chunks if chunk.parent_chunk_id in parent_ids) + valid_parent_links = sum( + 1 + for chunk in parent_chunks + if not chunk.child_chunk_ids or all(child_id in chunk_ids for child_id in chunk.child_chunk_ids) + ) + parent_child_denominator = len(child_chunks) + len(parent_chunks) + parent_child_resolution = ( + (valid_child_links + valid_parent_links) / parent_child_denominator + if parent_child_denominator + else 1.0 + ) + + boundary_ready = sum( + 1 + for chunk in chunks + if chunk.boundary_reason + and chunk.text.strip() + and chunk.page_start > 0 + and chunk.page_end >= chunk.page_start + ) + provenance_ready = sum( + 1 + for chunk in chunks + if chunk.doc_id + and chunk.source_parser + and chunk.source_parser != "unknown" + and chunk.page_start > 0 + and chunk.page_end >= chunk.page_start + ) + + metrics = parsed.quality_report.metrics + table_coverage = float(metrics.get("table_chunk_coverage", 1.0)) + figure_coverage = float(metrics.get("figure_chunk_coverage", 1.0)) + boundary_precision = boundary_ready / len(chunks) + provenance_completeness = provenance_ready / len(chunks) + retrieval_readiness = mean( + [ + boundary_precision, + parent_child_resolution, + provenance_completeness, + table_coverage, + figure_coverage, + ] + ) + token_counts = [chunk.token_count for chunk in chunks] + + return { + "source_path": source_path, + "doc_id": parsed.doc_id, + "chunk_count": len(chunks), + "boundary_precision": boundary_precision, + "parent_child_resolution": parent_child_resolution, + "provenance_completeness": provenance_completeness, + "retrieval_readiness": retrieval_readiness, + "table_chunk_coverage": table_coverage, + "figure_chunk_coverage": figure_coverage, + "avg_tokens": mean(token_counts) if token_counts else 0.0, + "max_tokens": max(token_counts) if token_counts else 0, + } diff --git a/zsgdp/benchmarks/cross_dataset.py b/zsgdp/benchmarks/cross_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a27bdedaa7b755c85b9e1ee018f6f522f289f0 --- /dev/null +++ b/zsgdp/benchmarks/cross_dataset.py @@ -0,0 +1,169 @@ +"""Combine multiple benchmark summaries into a cross-dataset comparison. + +Definitions (pinned): + +- Inputs are previously emitted results.json summaries (or directories that + contain one). Each summary is treated as one column in the comparison. +- Two views are produced: + - dataset_summary: one row per dataset/run with the headline document-level + metric means (quality, layout F1, table score, formula CER, recall@5, + disagreement rate, repair resolution rate). Direct apples-to-apples. + - parser_matrix: one row per (parser, dataset) cell built by joining each + summary's per_parser_gt_leaderboard. Cells are missing when the parser + didn't appear in that run; readers can spot coverage gaps directly. +- The combiner does not re-derive metrics; it reads the summary fields + computed by run_parser_benchmark. If a summary is missing a metric (older + run, different code version), that cell is None rather than 0.0 so the + caller can distinguish "missing" from "true zero." +""" + +from __future__ import annotations + +import csv +import json +from pathlib import Path +from typing import Any + +from zsgdp.utils import write_json + +DATASET_HEADLINE_METRICS = ( + "document_count", + "mean_quality_score", + "mean_parser_disagreement_rate", + "mean_repair_resolution_rate", + "mean_repair_regression_rate", + "mean_layout_f1", + "mean_layout_class_agnostic_f1", + "layout_evaluated_count", + "mean_table_structure_score", + "mean_table_match_rate", + "table_structure_evaluated_count", + "mean_formula_cer", + "mean_formula_accuracy", + "formula_evaluated_count", + "mean_retrieval_recall_at_1", + "mean_retrieval_recall_at_5", + "mean_retrieval_mrr", + "retrieval_evaluated_count", +) + +PARSER_MATRIX_METRICS = ( + "document_count", + "layout_evaluated_count", + "mean_layout_class_aware_f1", + "mean_layout_class_agnostic_f1", + "table_evaluated_count", + "mean_table_structure_score", + "mean_table_cell_content_f1", + "formula_evaluated_count", + "mean_formula_cer", + "mean_formula_accuracy", +) + + +def combine_benchmark_summaries( + inputs: list[tuple[str, dict[str, Any] | str | Path]], +) -> dict[str, Any]: + """Combine summaries provided as (label, summary_dict_or_path) pairs. + + A summary_dict_or_path may be a dict (already loaded), a file path to a + results.json, or a directory that contains results.json. Labels are used + as the dataset column key in the comparison output; duplicates raise. + """ + + labels_seen: set[str] = set() + runs: list[dict[str, Any]] = [] + for label, source in inputs: + if label in labels_seen: + raise ValueError(f"Duplicate label in cross-dataset comparison: {label}") + labels_seen.add(label) + summary = _load_summary(source) + runs.append({"label": label, "source": str(source) if not isinstance(source, dict) else "", "summary": summary}) + + dataset_rows = [_dataset_summary_row(run) for run in runs] + parser_matrix = _parser_matrix(runs) + return { + "run_count": len(runs), + "labels": [run["label"] for run in runs], + "dataset_summary": dataset_rows, + "parser_matrix": parser_matrix, + } + + +def write_cross_dataset_outputs(comparison: dict[str, Any], output_dir: str | Path) -> None: + output_root = Path(output_dir) + output_root.mkdir(parents=True, exist_ok=True) + write_json(output_root / "cross_dataset_comparison.json", comparison) + _write_dataset_summary_csv(output_root / "dataset_summary.csv", comparison["dataset_summary"]) + _write_parser_matrix_csv(output_root / "parser_matrix.csv", comparison["parser_matrix"], comparison["labels"]) + + +def _load_summary(source: dict[str, Any] | str | Path) -> dict[str, Any]: + if isinstance(source, dict): + return source + path = Path(source) + if path.is_dir(): + path = path / "results.json" + if not path.exists(): + raise FileNotFoundError(f"Benchmark summary not found: {path}") + return json.loads(path.read_text(encoding="utf-8")) + + +def _dataset_summary_row(run: dict[str, Any]) -> dict[str, Any]: + summary = run["summary"] + row = { + "label": run["label"], + "dataset_name": summary.get("dataset_name"), + "dataset_root": summary.get("dataset_root"), + } + for key in DATASET_HEADLINE_METRICS: + row[key] = summary.get(key) + return row + + +def _parser_matrix(runs: list[dict[str, Any]]) -> list[dict[str, Any]]: + parser_to_label_to_row: dict[str, dict[str, dict[str, Any]]] = {} + for run in runs: + leaderboard = run["summary"].get("per_parser_gt_leaderboard") or [] + for entry in leaderboard: + parser = entry.get("parser") + if not parser: + continue + parser_to_label_to_row.setdefault(parser, {})[run["label"]] = entry + + matrix: list[dict[str, Any]] = [] + for parser in sorted(parser_to_label_to_row): + row: dict[str, Any] = {"parser": parser} + for run in runs: + entry = parser_to_label_to_row[parser].get(run["label"]) + if entry is None: + for metric in PARSER_MATRIX_METRICS: + row[f"{run['label']}__{metric}"] = None + else: + for metric in PARSER_MATRIX_METRICS: + row[f"{run['label']}__{metric}"] = entry.get(metric) + matrix.append(row) + return matrix + + +def _write_dataset_summary_csv(path: Path, rows: list[dict[str, Any]]) -> None: + fieldnames = ["label", "dataset_name", "dataset_root", *DATASET_HEADLINE_METRICS] + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow({field: row.get(field, "") for field in fieldnames}) + + +def _write_parser_matrix_csv(path: Path, rows: list[dict[str, Any]], labels: list[str]) -> None: + fieldnames = ["parser"] + for label in labels: + for metric in PARSER_MATRIX_METRICS: + fieldnames.append(f"{label}__{metric}") + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + normalized = {field: row.get(field, "") for field in fieldnames} + normalized["parser"] = row.get("parser", "") + writer.writerow(normalized) diff --git a/zsgdp/benchmarks/datasets.py b/zsgdp/benchmarks/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bc514c12ab632ce443c74e814b1c4702a326ba --- /dev/null +++ b/zsgdp/benchmarks/datasets.py @@ -0,0 +1,226 @@ +"""Dataset loaders for benchmarking. + +Loader contract (pinned so downstream metrics can rely on it): + +- Each loader is a callable that accepts a dataset root Path and yields + DatasetDocument records. +- DatasetDocument.path points to the file the parser will be asked to parse. +- DatasetDocument.ground_truth is a dict in a loader-specific shape, or None + when the dataset has no labels (treated as "parse-only" benchmarking). +- Loaders never download data. They expect a local checkout under the given + root. They raise FileNotFoundError when the root is missing or empty. +- Loaders may filter to supported file types (e.g. PDFs only) and skip files + they cannot interpret. + +Three loaders ship by default: + +- custom_folder: walks every file in the root, no ground truth. Backward- + compatible behaviour for existing `zsgdp benchmark --input ` runs. +- omnidocbench: pairs each PDF with a sibling .json ground-truth file + (see https://github.com/opendatalab/OmniDocBench for the documented layout). +- doclaynet: reads a COCO-style annotations file (annotations.json or + annotations/*.json) and yields one DatasetDocument per image, with the + matched annotation list attached as ground_truth. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import json +from pathlib import Path +from typing import Any, Callable, Iterator + +from zsgdp.utils import to_plain_data + +DatasetLoader = Callable[[Path], Iterator["DatasetDocument"]] + + +@dataclass(slots=True) +class DatasetDocument: + dataset_id: str + doc_id: str + path: Path + ground_truth: dict[str, Any] | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + data = to_plain_data(self) + data["path"] = str(self.path) + return data + + +_LOADERS: dict[str, DatasetLoader] = {} + + +def register_dataset_loader(name: str, loader: DatasetLoader) -> None: + _LOADERS[name] = loader + + +def get_dataset_loader(name: str) -> DatasetLoader: + canonical = _canonical_name(name) + if canonical not in _LOADERS: + raise KeyError(f"Unknown dataset loader: {name!r}. Known: {sorted(_LOADERS)}") + return _LOADERS[canonical] + + +def list_dataset_loaders() -> list[str]: + return sorted(_LOADERS) + + +def iter_dataset(name: str, root: str | Path) -> Iterator[DatasetDocument]: + loader = get_dataset_loader(name) + yield from loader(Path(root)) + + +def _canonical_name(name: str) -> str: + if name in {"custom", "folder", "default"}: + return "custom_folder" + return name + + +def list_documents(dataset_path: str | Path) -> list[Path]: + """Backward-compatible filesystem walk used before loaders existed.""" + + root = Path(dataset_path) + return sorted(path for path in root.rglob("*") if path.is_file()) + + +def _load_custom_folder(root: Path) -> Iterator[DatasetDocument]: + if not root.exists(): + raise FileNotFoundError(f"Dataset root does not exist: {root}") + if not root.is_dir(): + raise NotADirectoryError(f"Dataset root must be a directory: {root}") + for path in sorted(item for item in root.iterdir() if item.is_file()): + yield DatasetDocument( + dataset_id="custom_folder", + doc_id=path.stem or path.name, + path=path, + ground_truth=None, + metadata={"source": "filesystem"}, + ) + + +def _load_omnidocbench(root: Path) -> Iterator[DatasetDocument]: + """Yield (pdf, ground_truth.json) pairs from an OmniDocBench checkout. + + Convention: for every .pdf under root, look for .json next to + it (any directory). Documents without a matching JSON are skipped with a + metadata note rather than failing the whole loader. + """ + + if not root.exists(): + raise FileNotFoundError(f"OmniDocBench root does not exist: {root}") + pdfs = sorted(path for path in root.rglob("*.pdf")) + if not pdfs: + raise FileNotFoundError(f"No PDFs found under OmniDocBench root: {root}") + for pdf in pdfs: + gt_path = _find_sibling_json(pdf) + ground_truth: dict[str, Any] | None = None + if gt_path is not None: + try: + ground_truth = json.loads(gt_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError) as exc: + ground_truth = {"_load_error": str(exc), "_path": str(gt_path)} + yield DatasetDocument( + dataset_id="omnidocbench", + doc_id=pdf.stem, + path=pdf, + ground_truth=ground_truth, + metadata={ + "ground_truth_path": str(gt_path) if gt_path else None, + "has_ground_truth": ground_truth is not None and "_load_error" not in (ground_truth or {}), + }, + ) + + +def _load_doclaynet(root: Path) -> Iterator[DatasetDocument]: + """Yield image documents matched against COCO-style DocLayNet annotations. + + Looks for annotations.json at the root, then under annotations/. Builds + image-id -> annotations index once, then yields one DatasetDocument per + image referenced in the COCO file. Image paths are resolved relative to + the annotation file's parent and to root. + """ + + if not root.exists(): + raise FileNotFoundError(f"DocLayNet root does not exist: {root}") + annotations_path = _find_doclaynet_annotations(root) + if annotations_path is None: + raise FileNotFoundError( + f"No DocLayNet annotations.json found under {root}. Expected annotations.json at the root or under annotations/." + ) + coco = json.loads(annotations_path.read_text(encoding="utf-8")) + images = coco.get("images") or [] + annotations = coco.get("annotations") or [] + categories = {int(category["id"]): category for category in coco.get("categories") or [] if "id" in category} + annotations_by_image: dict[int, list[dict[str, Any]]] = {} + for annotation in annotations: + image_id = annotation.get("image_id") + if image_id is None: + continue + annotations_by_image.setdefault(int(image_id), []).append(annotation) + + image_root = annotations_path.parent + for image in images: + image_id = image.get("id") + file_name = image.get("file_name") + if image_id is None or not file_name: + continue + path = _resolve_doclaynet_image_path(file_name, image_root, root) + if path is None: + continue + yield DatasetDocument( + dataset_id="doclaynet", + doc_id=str(image.get("file_name") or image_id), + path=path, + ground_truth={ + "annotations": annotations_by_image.get(int(image_id), []), + "image": image, + "categories": categories, + }, + metadata={ + "annotations_path": str(annotations_path), + "image_id": image_id, + "annotation_count": len(annotations_by_image.get(int(image_id), [])), + }, + ) + + +def _find_sibling_json(pdf: Path) -> Path | None: + candidate = pdf.with_suffix(".json") + if candidate.exists(): + return candidate + for sibling in pdf.parent.glob(f"{pdf.stem}*.json"): + return sibling + return None + + +def _find_doclaynet_annotations(root: Path) -> Path | None: + direct = root / "annotations.json" + if direct.exists(): + return direct + annotations_dir = root / "annotations" + if annotations_dir.is_dir(): + for path in sorted(annotations_dir.glob("*.json")): + return path + for path in sorted(root.glob("*.json")): + return path + return None + + +def _resolve_doclaynet_image_path(file_name: str, image_root: Path, dataset_root: Path) -> Path | None: + candidates = [ + image_root / file_name, + dataset_root / file_name, + dataset_root / "images" / file_name, + dataset_root / "PNG" / file_name, + ] + for candidate in candidates: + if candidate.exists(): + return candidate + return None + + +register_dataset_loader("custom_folder", _load_custom_folder) +register_dataset_loader("omnidocbench", _load_omnidocbench) +register_dataset_loader("doclaynet", _load_doclaynet) diff --git a/zsgdp/benchmarks/embedding_retriever.py b/zsgdp/benchmarks/embedding_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..abffc827bb94323c88488b97a404b7e656f6d318 --- /dev/null +++ b/zsgdp/benchmarks/embedding_retriever.py @@ -0,0 +1,146 @@ +"""Embedding-based retriever for the retrieval benchmark. + +Definitions and contract (pinned): + +- EmbeddingRetriever satisfies the Retriever Protocol from + zsgdp.benchmarks.retrieval. It is opt-in: the benchmark default stays + LexicalRetriever (model-free) so CI and local runs do not pay model + download cost unintentionally. +- The embedder is a callable: list[str] -> list[list[float]]. It must + return L2-normalized vectors so cosine similarity reduces to a dot + product. Two construction paths: + - Pass `embedder=...` directly (used by tests and any caller that wants + full control over batching, device placement, or remote inference). + - Pass `model_id=...` and let the retriever lazy-load + sentence-transformers. Selected through `build_retriever` from config + by setting `benchmarks.retriever.backend = "embedding"`. +- Index and query both call the embedder. The retriever is stateless + beyond the indexed chunk vectors; reusing across documents requires a + fresh `index()` call, same contract as LexicalRetriever. +- The default model_id (`jinaai/jina-embeddings-v3`) and task + (`retrieval.passage`) match the project spec / configs/default.yaml. + Other sentence-transformers models work as long as they accept the same + encode signature; jina's task-prompt argument is optional and silently + ignored by models that don't recognize it. +""" + +from __future__ import annotations + +from typing import Any, Callable, Sequence + +from zsgdp.schema import Chunk + +Embedder = Callable[[list[str]], list[list[float]]] + + +class EmbeddingRetriever: + def __init__( + self, + *, + model_id: str = "jinaai/jina-embeddings-v3", + task: str | None = "retrieval.passage", + query_task: str | None = None, + embedder: Embedder | None = None, + ) -> None: + self._model_id = model_id + self._task = task + self._query_task = query_task or task + self._explicit_embedder = embedder + self._embedder: Embedder | None = embedder + self._chunk_ids: list[str] = [] + self._vectors: list[list[float]] = [] + + def index(self, chunks: Sequence[Chunk]) -> None: + embedder = self._ensure_embedder() + texts = [chunk.text for chunk in chunks] + if not texts: + self._chunk_ids = [] + self._vectors = [] + return + vectors = embedder(texts) + if len(vectors) != len(texts): + raise RuntimeError( + f"EmbeddingRetriever embedder returned {len(vectors)} vectors for {len(texts)} chunks." + ) + self._chunk_ids = [chunk.chunk_id for chunk in chunks] + self._vectors = [_normalize(list(vector)) for vector in vectors] + + def query(self, text: str, *, top_k: int) -> list[str]: + if not self._vectors: + return [] + embedder = self._ensure_embedder() + query_vec = embedder([text]) + if not query_vec: + return [] + query_vector = _normalize(list(query_vec[0])) + if not query_vector: + return [] + scored: list[tuple[float, int]] = [] + for index, vector in enumerate(self._vectors): + score = _dot(query_vector, vector) + if score > 0: + scored.append((score, index)) + scored.sort(key=lambda item: (-item[0], item[1])) + return [self._chunk_ids[index] for _score, index in scored[:top_k]] + + def _ensure_embedder(self) -> Embedder: + if self._embedder is not None: + return self._embedder + try: + from sentence_transformers import SentenceTransformer # type: ignore + except ImportError as exc: + raise RuntimeError( + "EmbeddingRetriever requires sentence-transformers. " + "Install with `pip install sentence-transformers` or pass `embedder=...` explicitly." + ) from exc + + model = SentenceTransformer(self._model_id, trust_remote_code=True) + + def encode(texts: list[str]) -> list[list[float]]: + kwargs: dict[str, Any] = {"normalize_embeddings": True} + if self._task: + kwargs["task"] = self._task + try: + vectors = model.encode(texts, **kwargs) + except TypeError: + # Fallback for models that don't accept the `task` kwarg. + kwargs.pop("task", None) + vectors = model.encode(texts, **kwargs) + return [list(map(float, vector)) for vector in vectors] + + self._embedder = encode + return encode + + +def build_retriever(config: dict[str, Any] | None) -> Any: + """Pick a retriever based on benchmarks.retriever config; default is lexical.""" + + from zsgdp.benchmarks.retrieval import LexicalRetriever + + config = config or {} + benchmarks = config.get("benchmarks", {}) if isinstance(config.get("benchmarks", {}), dict) else {} + retriever_cfg = benchmarks.get("retriever", {}) if isinstance(benchmarks.get("retriever", {}), dict) else {} + backend = str(retriever_cfg.get("backend", "lexical")).strip().lower() + + if backend in {"", "lexical", "tfidf", "tf-idf"}: + return LexicalRetriever() + if backend in {"embedding", "embeddings", "jina", "sentence-transformers"}: + embedding_config = config.get("gpu", {}).get("models", {}).get("embedding", {}) if isinstance(config.get("gpu", {}), dict) else {} + model_id = retriever_cfg.get("model_id") or embedding_config.get("model_id") or "jinaai/jina-embeddings-v3" + task = retriever_cfg.get("task") or embedding_config.get("task") or "retrieval.passage" + query_task = retriever_cfg.get("query_task") + return EmbeddingRetriever(model_id=model_id, task=task, query_task=query_task) + raise ValueError(f"Unknown retriever backend: {backend!r}. Use 'lexical' or 'embedding'.") + + +def _normalize(vector: list[float]) -> list[float]: + norm = sum(value * value for value in vector) ** 0.5 + if norm == 0: + return [] + return [value / norm for value in vector] + + +def _dot(a: list[float], b: list[float]) -> float: + if len(a) != len(b): + return 0.0 + return sum(x * y for x, y in zip(a, b)) diff --git a/zsgdp/benchmarks/ground_truth.py b/zsgdp/benchmarks/ground_truth.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ff562688ba1c38689afe061f5f8ce8d6856d6e --- /dev/null +++ b/zsgdp/benchmarks/ground_truth.py @@ -0,0 +1,391 @@ +"""Adapters from dataset ground truth + parsed output to LayoutItem records. + +Each adapter returns a list of dicts with keys (bbox, category, page_num) +matching the LayoutItem contract in zsgdp.verify.layout_f1. + +- doclaynet_layout_truths converts COCO `[x, y, w, h]` boxes to xyxy and + normalizes category names (Title -> title, Section-header -> heading, etc.) + into the same vocabulary parsed_layout_predictions emits. +- omnidocbench_layout_truths is best-effort: it inspects a small number of + known shapes (`reading_order`, `elements`, `layout_dets`) and pulls bbox + + category from whichever it finds. Returns an empty list when the JSON + doesn't match any known shape, so the caller can detect "ground truth + present but unparseable" via metadata rather than crashing. +- parsed_layout_predictions extracts boxes from parsed.elements, plus tables + and figures that carry bboxes but may not appear in elements (e.g. table + objects that came from a layout-aware parser). Element types are mapped + through the same canonical category vocabulary as truths. +""" + +from __future__ import annotations + +from typing import Any + +from zsgdp.schema import ParsedDocument + +CANONICAL_CATEGORIES = { + "title": "title", + "heading": "heading", + "section-header": "heading", + "section_header": "heading", + "subtitle": "heading", + "text": "paragraph", + "paragraph": "paragraph", + "body": "paragraph", + "list-item": "list_item", + "list_item": "list_item", + "list": "list_item", + "table": "table", + "figure": "figure", + "picture": "figure", + "image": "figure", + "chart": "figure", + "formula": "formula", + "equation": "formula", + "caption": "caption", + "footnote": "footnote", + "page-header": "header", + "page_header": "header", + "header": "header", + "page-footer": "footer", + "page_footer": "footer", + "footer": "footer", + "page-number": "footer", + "page_number": "footer", +} + + +def canonical_category(name: Any) -> str: + return CANONICAL_CATEGORIES.get(str(name).strip().lower(), str(name).strip().lower()) + + +def doclaynet_layout_truths(ground_truth: dict[str, Any] | None) -> list[dict[str, Any]]: + if not isinstance(ground_truth, dict): + return [] + image = ground_truth.get("image") or {} + annotations = ground_truth.get("annotations") or [] + categories = ground_truth.get("categories") or {} + if isinstance(categories, list): + categories = {int(item["id"]): item for item in categories if isinstance(item, dict) and "id" in item} + page_num = _doclaynet_page_num(image) + + truths: list[dict[str, Any]] = [] + for annotation in annotations: + if not isinstance(annotation, dict): + continue + bbox = _xywh_to_xyxy(annotation.get("bbox")) + if bbox is None: + continue + category = _doclaynet_category_name(annotation, categories) + truths.append( + { + "bbox": bbox, + "category": canonical_category(category), + "page_num": page_num, + } + ) + return truths + + +def omnidocbench_layout_truths(ground_truth: dict[str, Any] | None) -> list[dict[str, Any]]: + if not isinstance(ground_truth, dict): + return [] + truths: list[dict[str, Any]] = [] + for record in _omnidocbench_records(ground_truth): + bbox = _coerce_bbox(record.get("bbox") or record.get("poly") or record.get("box")) + if bbox is None: + continue + category = record.get("category") or record.get("type") or record.get("label") or record.get("class") + page_num = record.get("page_num") or record.get("page") or record.get("page_idx") + truths.append( + { + "bbox": bbox, + "category": canonical_category(category or ""), + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + } + ) + return truths + + +def parsed_layout_predictions(parsed: ParsedDocument) -> list[dict[str, Any]]: + return layout_predictions_from_items( + elements=getattr(parsed, "elements", []), + tables=getattr(parsed, "tables", []), + figures=getattr(parsed, "figures", []), + ) + + +def layout_predictions_from_items( + *, + elements: Any, + tables: Any, + figures: Any, +) -> list[dict[str, Any]]: + predictions: list[dict[str, Any]] = [] + for element in elements or []: + bbox = _coerce_bbox(_attr_or_key(element, "bbox")) + if bbox is None: + continue + page_num = _attr_or_key(element, "page_num") + predictions.append( + { + "bbox": bbox, + "category": canonical_category(_attr_or_key(element, "type") or ""), + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + } + ) + seen_table_keys: set[tuple[int, tuple[float, float, float, float]]] = set() + for table in tables or []: + raw_bbox = _attr_or_key(table, "bbox") + if not raw_bbox: + continue + page_nums = _attr_or_key(table, "page_nums") or [] + first_page = int(page_nums[0]) if page_nums else 0 + bboxes = ( + raw_bbox + if (isinstance(raw_bbox, list) and raw_bbox and isinstance(raw_bbox[0], (list, tuple))) + else [raw_bbox] + ) + for candidate in bboxes: + bbox = _coerce_bbox(candidate) + if bbox is None: + continue + key = (first_page, bbox) + if key in seen_table_keys: + continue + seen_table_keys.add(key) + predictions.append({"bbox": bbox, "category": "table", "page_num": first_page}) + for figure in figures or []: + bbox = _coerce_bbox(_attr_or_key(figure, "bbox")) + if bbox is None: + continue + page_num = _attr_or_key(figure, "page_num") + predictions.append( + { + "bbox": bbox, + "category": "figure", + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + } + ) + return predictions + + +def _attr_or_key(obj: Any, name: str) -> Any: + if isinstance(obj, dict): + return obj.get(name) + return getattr(obj, name, None) + + +def _xywh_to_xyxy(raw: Any) -> tuple[float, float, float, float] | None: + if not isinstance(raw, (list, tuple)) or len(raw) < 4: + return None + try: + x, y, w, h = (float(raw[0]), float(raw[1]), float(raw[2]), float(raw[3])) + except (TypeError, ValueError): + return None + if w <= 0 or h <= 0: + return None + return (x, y, x + w, y + h) + + +def _coerce_bbox(raw: Any) -> tuple[float, float, float, float] | None: + if not isinstance(raw, (list, tuple)) or len(raw) < 4: + return None + try: + x0, y0, x1, y1 = (float(raw[0]), float(raw[1]), float(raw[2]), float(raw[3])) + except (TypeError, ValueError): + return None + if x1 <= x0 or y1 <= y0: + return None + return (x0, y0, x1, y1) + + +def _doclaynet_page_num(image: dict[str, Any]) -> int | None: + for key in ("page_no", "page_num", "page", "page_number"): + value = image.get(key) + if isinstance(value, (int, float)): + return int(value) + return 1 + + +def _doclaynet_category_name(annotation: dict[str, Any], categories: dict[Any, Any]) -> str: + if "category_name" in annotation: + return str(annotation["category_name"]) + category_id = annotation.get("category_id") + if category_id is None: + return str(annotation.get("category", "")) + category = categories.get(int(category_id)) if isinstance(category_id, (int, float)) else None + if isinstance(category, dict): + return str(category.get("name", "")) + return "" + + +def omnidocbench_table_truths(ground_truth: dict[str, Any] | None) -> list[dict[str, Any]]: + """Pull table records (markdown/html + page_num) from OmniDocBench JSON.""" + + truths: list[dict[str, Any]] = [] + if not isinstance(ground_truth, dict): + return truths + for record in _omnidocbench_records(ground_truth): + category = canonical_category(record.get("category") or record.get("type") or record.get("label") or "") + if category != "table": + continue + markdown = record.get("markdown") or record.get("table_markdown") + html = record.get("html") or record.get("table_html") + if not markdown and not html: + continue + page_num = record.get("page_num") or record.get("page") or record.get("page_idx") + truths.append( + { + "markdown": markdown, + "html": html, + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + "table_id": record.get("table_id") or record.get("id"), + "bbox": _coerce_bbox(record.get("bbox")), + } + ) + return truths + + +def omnidocbench_formula_truths(ground_truth: dict[str, Any] | None) -> list[dict[str, Any]]: + """Pull formula records (latex + page_num) from OmniDocBench JSON.""" + + truths: list[dict[str, Any]] = [] + if not isinstance(ground_truth, dict): + return truths + for record in _omnidocbench_records(ground_truth): + category = canonical_category(record.get("category") or record.get("type") or record.get("label") or "") + if category != "formula": + continue + latex = record.get("latex") or record.get("text") or record.get("markdown") + if not latex or not str(latex).strip(): + continue + page_num = record.get("page_num") or record.get("page") or record.get("page_idx") + truths.append( + { + "latex": str(latex), + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + "formula_id": record.get("formula_id") or record.get("id"), + "bbox": _coerce_bbox(record.get("bbox")), + } + ) + return truths + + +def parsed_table_records(parsed) -> list[dict[str, Any]]: + return table_records_from_items( + elements=getattr(parsed, "elements", []), + tables=getattr(parsed, "tables", []), + ) + + +def table_records_from_items( + *, + elements: Any, + tables: Any, +) -> list[dict[str, Any]]: + records: list[dict[str, Any]] = [] + seen_keys: set[tuple] = set() + for table in tables or []: + markdown = _attr_or_key(table, "markdown") + html = _attr_or_key(table, "html") + if not markdown and not html: + continue + page_nums = _attr_or_key(table, "page_nums") or [] + page_num = page_nums[0] if page_nums else None + table_id = _attr_or_key(table, "table_id") + key = ("table", table_id, page_num) + if key in seen_keys: + continue + seen_keys.add(key) + records.append( + { + "markdown": markdown, + "html": html, + "page_num": page_num, + "table_id": table_id, + "bbox": _coerce_first_bbox(_attr_or_key(table, "bbox")), + } + ) + for element in elements or []: + if _attr_or_key(element, "type") != "table": + continue + markdown = _attr_or_key(element, "markdown") + html = _attr_or_key(element, "html") + if not markdown and not html: + continue + element_id = _attr_or_key(element, "element_id") + page_num = _attr_or_key(element, "page_num") + key = ("element", element_id, page_num) + if key in seen_keys: + continue + seen_keys.add(key) + records.append( + { + "markdown": markdown, + "html": html, + "page_num": page_num, + "table_id": element_id, + "bbox": _coerce_bbox(_attr_or_key(element, "bbox")), + } + ) + return records + + +def parsed_formula_records(parsed) -> list[dict[str, Any]]: + return formula_records_from_items(elements=getattr(parsed, "elements", [])) + + +def formula_records_from_items(*, elements: Any) -> list[dict[str, Any]]: + records: list[dict[str, Any]] = [] + for element in elements or []: + if _attr_or_key(element, "type") != "formula": + continue + latex = ( + _attr_or_key(element, "markdown") + or _attr_or_key(element, "text") + or _attr_or_key(element, "html") + ) + if not latex or not str(latex).strip(): + continue + records.append( + { + "latex": str(latex), + "page_num": _attr_or_key(element, "page_num"), + "formula_id": _attr_or_key(element, "element_id"), + "bbox": _coerce_bbox(_attr_or_key(element, "bbox")), + } + ) + return records + + +def _coerce_first_bbox(raw: Any) -> tuple[float, float, float, float] | None: + if isinstance(raw, list) and raw and isinstance(raw[0], (list, tuple)): + return _coerce_bbox(raw[0]) + return _coerce_bbox(raw) + + +def _omnidocbench_records(ground_truth: dict[str, Any]) -> list[dict[str, Any]]: + for key in ("layout_dets", "elements", "annotations", "regions", "blocks"): + value = ground_truth.get(key) + if isinstance(value, list): + return [item for item in value if isinstance(item, dict)] + pages = ground_truth.get("pages") + if isinstance(pages, list): + records: list[dict[str, Any]] = [] + for page_index, page in enumerate(pages, start=1): + if not isinstance(page, dict): + continue + page_num = page.get("page_num") or page.get("page") or page_index + for key in ("layout_dets", "elements", "annotations", "regions"): + value = page.get(key) + if not isinstance(value, list): + continue + for item in value: + if not isinstance(item, dict): + continue + enriched = {**item} + enriched.setdefault("page_num", page_num) + records.append(enriched) + return records + return [] diff --git a/zsgdp/benchmarks/parser_quality.py b/zsgdp/benchmarks/parser_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..bc58509411bf86b6706fa28e5caad1938dfb9adc --- /dev/null +++ b/zsgdp/benchmarks/parser_quality.py @@ -0,0 +1,944 @@ +"""Parser-quality benchmark helpers.""" + +from __future__ import annotations + +import csv +from pathlib import Path +from statistics import mean +from time import perf_counter + +from zsgdp.benchmarks.ablations import ablation_plan +from zsgdp.benchmarks.chunking_quality import chunking_quality_record, score_chunking_quality +from zsgdp.benchmarks.datasets import DatasetDocument, get_dataset_loader +from zsgdp.benchmarks.ground_truth import ( + doclaynet_layout_truths, + omnidocbench_formula_truths, + omnidocbench_layout_truths, + omnidocbench_table_truths, + parsed_formula_records, + parsed_layout_predictions, + parsed_table_records, +) +from zsgdp.benchmarks.per_parser_metrics import compute_per_parser_metrics +from zsgdp.benchmarks.retrieval import run_retrieval_for_document +from zsgdp.benchmarks.structure_quality import score_structure_quality, structure_quality_record +from zsgdp.benchmarks.throughput import summarize_throughput, throughput_record +from zsgdp.config import load_config +from zsgdp.pipeline import parse_document +from zsgdp.utils import write_json +from zsgdp.verify.formula_extraction import compute_formula_extraction +from zsgdp.verify.layout_f1 import compute_layout_f1 +from zsgdp.verify.retrieval import compute_retrieval_metrics +from zsgdp.verify.table_structure import compute_table_structure_score + + +GROUND_TRUTH_ADAPTERS = { + "doclaynet": doclaynet_layout_truths, + "omnidocbench": omnidocbench_layout_truths, +} + +TABLE_TRUTH_ADAPTERS = { + "omnidocbench": omnidocbench_table_truths, +} + +FORMULA_TRUTH_ADAPTERS = { + "omnidocbench": omnidocbench_formula_truths, +} + + +def score_parser_quality(records: list[dict]) -> dict[str, float]: + if not records: + return {"mean_quality_score": 0.0} + return { + "mean_quality_score": sum(float(record.get("quality_score", 0.0)) for record in records) / len(records) + } + + +def run_parser_benchmark( + input_dir: str | Path, + output_dir: str | Path, + *, + config_path: str | Path | None = None, + selected_parsers: list[str] | None = None, + dataset_name: str = "custom_folder", +) -> dict: + input_root = Path(input_dir) + output_root = Path(output_dir) + output_root.mkdir(parents=True, exist_ok=True) + loader = get_dataset_loader(dataset_name) + bench_config = load_config(config_path) + documents: list[dict] = [] + parser_rows: list[dict] = [] + chunk_rows: list[dict] = [] + structure_rows: list[dict] = [] + chunk_quality_rows: list[dict] = [] + throughput_rows: list[dict] = [] + repair_rows: list[dict] = [] + layout_rows: list[dict] = [] + table_structure_rows: list[dict] = [] + formula_rows: list[dict] = [] + retrieval_rows: list[dict] = [] + per_parser_rows: list[dict] = [] + observed_chunk_strategies: set[str] = set() + + for dataset_document in loader(input_root): + path = dataset_document.path + doc_out = output_root / "parsed" / (dataset_document.doc_id or path.stem) + started = perf_counter() + parsed = parse_document(path, doc_out, config_path=config_path, selected_parsers=selected_parsers) + elapsed_seconds = perf_counter() - started + chunk_strategy_counts = _chunk_strategy_counts(parsed.chunks) + observed_chunk_strategies.update(chunk_strategy_counts) + disagreement = parsed.provenance.get("parser_disagreement", {}) or {} + repair_success = parsed.provenance.get("repair_success", {}) or {} + layout_metrics = _layout_metrics(dataset_document, parsed) + table_metrics = _table_structure_metrics(dataset_document, parsed) + formula_metrics = _formula_extraction_metrics(dataset_document, parsed) + retrieval_metrics = _retrieval_metrics(parsed, bench_config) + contribution = _parser_contribution(parsed) + per_parser = _compute_per_parser_block(dataset_document, parsed) + doc_record = { + "source_path": str(path), + "dataset_id": dataset_document.dataset_id, + "dataset_doc_id": dataset_document.doc_id, + "has_ground_truth": dataset_document.ground_truth is not None, + "doc_id": parsed.doc_id, + "file_type": parsed.file_type, + "quality_score": parsed.quality_report.score, + "element_count": len(parsed.elements), + "table_count": len(parsed.tables), + "figure_count": len(parsed.figures), + "chunk_count": len(parsed.chunks), + "chunk_strategy_counts": chunk_strategy_counts, + "chunk_quality_metrics": _chunk_quality_metrics(parsed.quality_report.metrics), + "elapsed_seconds": elapsed_seconds, + "parser_metrics": parsed.provenance.get("parser_metrics", {}), + "parser_failures": parsed.provenance.get("parser_failures", {}), + "parser_disagreement_rate": float(disagreement.get("disagreement_rate", 0.0)), + "candidate_count": int(disagreement.get("candidate_count", 0)), + "conflict_count": int(disagreement.get("conflict_count", 0)), + "repair_resolution_rate": float(repair_success.get("repair_resolution_rate", 1.0)), + "repair_regression_rate": float(repair_success.get("repair_regression_rate", 0.0)), + "repair_iteration_count": int(repair_success.get("iteration_count", 0)), + "layout_f1": layout_metrics["summary"]["class_aware_f1"], + "layout_class_agnostic_f1": layout_metrics["summary"]["class_agnostic_f1"], + "layout_evaluated": layout_metrics["summary"]["evaluated"], + "table_structure_score": table_metrics["summary"]["mean_table_score"], + "table_match_rate": table_metrics["summary"]["table_match_rate"], + "table_structure_evaluated": table_metrics["summary"]["evaluated"], + "formula_cer": formula_metrics["summary"]["mean_cer"], + "formula_accuracy": formula_metrics["summary"]["mean_accuracy"], + "formula_exact_match_rate": formula_metrics["summary"]["exact_match_rate"], + "formula_evaluated": formula_metrics["summary"]["evaluated"], + "retrieval_recall_at_1": retrieval_metrics["summary"]["recall_at_1"], + "retrieval_recall_at_5": retrieval_metrics["summary"]["recall_at_5"], + "retrieval_mrr": retrieval_metrics["summary"]["mean_reciprocal_rank"], + "retrieval_query_count": retrieval_metrics["summary"]["query_count"], + "retrieval_evaluated": retrieval_metrics["summary"]["evaluated"], + "parser_contribution_counts": contribution["counts"], + "parser_contribution_fractions": contribution["fractions"], + "per_parser_metrics": per_parser, + } + documents.append(doc_record) + for parser_name, block in per_parser.items(): + per_parser_rows.append(_per_parser_row(path, parsed, dataset_document, parser_name, block)) + repair_rows.append(_repair_row(path, parsed, repair_success, disagreement)) + if layout_metrics["summary"]["evaluated"]: + layout_rows.append(_layout_row(path, parsed, dataset_document, layout_metrics)) + if table_metrics["summary"]["evaluated"]: + table_structure_rows.append(_table_structure_row(path, parsed, dataset_document, table_metrics)) + if formula_metrics["summary"]["evaluated"]: + formula_rows.append(_formula_row(path, parsed, dataset_document, formula_metrics)) + if retrieval_metrics["summary"]["evaluated"]: + retrieval_rows.append(_retrieval_row(path, parsed, retrieval_metrics)) + for parser_name, metrics in doc_record["parser_metrics"].items(): + parser_rows.append( + { + "source_path": str(path), + "doc_id": parsed.doc_id, + "parser": parser_name, + **metrics, + } + ) + chunk_rows.extend(_chunk_strategy_rows(path, parsed)) + structure_rows.append(structure_quality_record(parsed, str(path))) + chunk_quality_rows.append(chunking_quality_record(parsed, str(path))) + throughput_rows.append(throughput_record(parsed, str(path), elapsed_seconds)) + + config = bench_config + parser_contribution_summary = _aggregate_parser_contributions(documents) + summary = { + "dataset_name": dataset_name, + "dataset_root": str(input_root), + "document_count": len(documents), + "parser_contribution_summary": parser_contribution_summary, + "mean_quality_score": mean([doc["quality_score"] for doc in documents]) if documents else 0.0, + "mean_parser_disagreement_rate": _mean_value(documents, "parser_disagreement_rate"), + "mean_repair_resolution_rate": _mean_value(documents, "repair_resolution_rate"), + "mean_repair_regression_rate": _mean_value(documents, "repair_regression_rate"), + "mean_layout_f1": _mean_value( + [doc for doc in documents if doc.get("layout_evaluated")], "layout_f1" + ), + "mean_layout_class_agnostic_f1": _mean_value( + [doc for doc in documents if doc.get("layout_evaluated")], "layout_class_agnostic_f1" + ), + "layout_evaluated_count": sum(1 for doc in documents if doc.get("layout_evaluated")), + "mean_table_structure_score": _mean_value( + [doc for doc in documents if doc.get("table_structure_evaluated")], "table_structure_score" + ), + "mean_table_match_rate": _mean_value( + [doc for doc in documents if doc.get("table_structure_evaluated")], "table_match_rate" + ), + "table_structure_evaluated_count": sum(1 for doc in documents if doc.get("table_structure_evaluated")), + "mean_formula_cer": _mean_value( + [doc for doc in documents if doc.get("formula_evaluated")], "formula_cer" + ), + "mean_formula_accuracy": _mean_value( + [doc for doc in documents if doc.get("formula_evaluated")], "formula_accuracy" + ), + "formula_evaluated_count": sum(1 for doc in documents if doc.get("formula_evaluated")), + "mean_retrieval_recall_at_1": _mean_value( + [doc for doc in documents if doc.get("retrieval_evaluated")], "retrieval_recall_at_1" + ), + "mean_retrieval_recall_at_5": _mean_value( + [doc for doc in documents if doc.get("retrieval_evaluated")], "retrieval_recall_at_5" + ), + "mean_retrieval_mrr": _mean_value( + [doc for doc in documents if doc.get("retrieval_evaluated")], "retrieval_mrr" + ), + "retrieval_evaluated_count": sum(1 for doc in documents if doc.get("retrieval_evaluated")), + "documents": documents, + "parser_leaderboard": _parser_leaderboard(parser_rows), + "per_parser_gt_leaderboard": _per_parser_gt_leaderboard(per_parser_rows), + "chunk_strategy_leaderboard": _chunk_strategy_leaderboard(chunk_rows), + "structure_quality": score_structure_quality(structure_rows), + "chunking_quality": score_chunking_quality(chunk_quality_rows), + "throughput": summarize_throughput(throughput_rows), + "ablation_plan": ablation_plan(config, observed_chunk_strategies), + } + write_json(output_root / "results.json", summary) + write_json(output_root / "ablations.json", summary["ablation_plan"]) + _write_leaderboard_csv(output_root / "leaderboard.csv", summary["parser_leaderboard"]) + _write_parser_rows_csv(output_root / "parser_runs.csv", parser_rows) + _write_chunk_rows_csv(output_root / "chunk_runs.csv", chunk_rows) + _write_structure_rows_csv(output_root / "structure_runs.csv", structure_rows) + _write_chunk_quality_rows_csv(output_root / "chunk_quality.csv", chunk_quality_rows) + _write_throughput_rows_csv(output_root / "throughput_runs.csv", throughput_rows) + _write_repair_rows_csv(output_root / "repair_runs.csv", repair_rows) + _write_layout_rows_csv(output_root / "layout_runs.csv", layout_rows) + _write_table_structure_rows_csv(output_root / "table_structure_runs.csv", table_structure_rows) + _write_formula_rows_csv(output_root / "formula_runs.csv", formula_rows) + _write_retrieval_rows_csv(output_root / "retrieval_runs.csv", retrieval_rows) + _write_per_parser_rows_csv(output_root / "per_parser_metrics.csv", per_parser_rows) + _write_per_parser_gt_leaderboard_csv( + output_root / "per_parser_gt_leaderboard.csv", summary["per_parser_gt_leaderboard"] + ) + return summary + + +def _per_parser_gt_leaderboard(rows: list[dict]) -> list[dict]: + """Aggregate per-document per-parser rows into one leaderboard row per parser. + + A metric contributes to a parser's mean only when that parser actually + had a non-zero prediction count for that metric on that document; this + keeps "0.00 from no predictions" from dragging the mean down for parsers + that simply don't emit bboxes (text/markdown). The number of documents + contributing to each metric is reported alongside the mean. + """ + + grouped: dict[str, list[dict]] = {} + for row in rows: + grouped.setdefault(row["parser"], []).append(row) + + leaderboard: list[dict] = [] + for parser_name, parser_rows in grouped.items(): + layout_rows = [row for row in parser_rows if row.get("layout_evaluated")] + table_rows = [row for row in parser_rows if row.get("table_evaluated")] + formula_rows = [row for row in parser_rows if row.get("formula_evaluated")] + leaderboard.append( + { + "parser": parser_name, + "document_count": len(parser_rows), + "layout_evaluated_count": len(layout_rows), + "mean_layout_class_aware_f1": _mean_value(layout_rows, "layout_class_aware_f1"), + "mean_layout_class_agnostic_f1": _mean_value(layout_rows, "layout_class_agnostic_f1"), + "mean_layout_precision": _mean_value(layout_rows, "layout_class_aware_precision"), + "mean_layout_recall": _mean_value(layout_rows, "layout_class_aware_recall"), + "table_evaluated_count": len(table_rows), + "mean_table_structure_score": _mean_value(table_rows, "table_structure_score"), + "mean_table_match_rate": _mean_value(table_rows, "table_match_rate"), + "mean_table_cell_content_f1": _mean_value(table_rows, "table_cell_content_f1"), + "formula_evaluated_count": len(formula_rows), + "mean_formula_cer": _mean_value(formula_rows, "formula_cer"), + "mean_formula_accuracy": _mean_value(formula_rows, "formula_accuracy"), + "mean_formula_exact_match_rate": _mean_value(formula_rows, "formula_exact_match_rate"), + "mean_element_count": _mean_value(parser_rows, "element_count"), + "mean_table_count": _mean_value(parser_rows, "table_count"), + "mean_figure_count": _mean_value(parser_rows, "figure_count"), + } + ) + return sorted( + leaderboard, + key=lambda row: ( + row["mean_layout_class_aware_f1"], + row["mean_table_structure_score"], + -row["mean_formula_cer"], + ), + reverse=True, + ) + + +def _write_per_parser_gt_leaderboard_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "parser", + "document_count", + "layout_evaluated_count", + "mean_layout_class_aware_f1", + "mean_layout_class_agnostic_f1", + "mean_layout_precision", + "mean_layout_recall", + "table_evaluated_count", + "mean_table_structure_score", + "mean_table_match_rate", + "mean_table_cell_content_f1", + "formula_evaluated_count", + "mean_formula_cer", + "mean_formula_accuracy", + "mean_formula_exact_match_rate", + "mean_element_count", + "mean_table_count", + "mean_figure_count", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _compute_per_parser_block(dataset_document: DatasetDocument, parsed) -> dict[str, dict[str, Any]]: + layout_adapter = GROUND_TRUTH_ADAPTERS.get(dataset_document.dataset_id) + table_adapter = TABLE_TRUTH_ADAPTERS.get(dataset_document.dataset_id) + formula_adapter = FORMULA_TRUTH_ADAPTERS.get(dataset_document.dataset_id) + gt = dataset_document.ground_truth + layout_truths = layout_adapter(gt) if (layout_adapter and gt is not None) else None + table_truths = table_adapter(gt) if (table_adapter and gt is not None) else None + formula_truths = formula_adapter(gt) if (formula_adapter and gt is not None) else None + if not (layout_truths or table_truths or formula_truths): + return {} + return compute_per_parser_metrics( + parsed, + layout_truths=layout_truths or None, + table_truths=table_truths or None, + formula_truths=formula_truths or None, + ) + + +def _per_parser_row( + path: Path, + parsed, + dataset_document: DatasetDocument, + parser_name: str, + block: dict[str, Any], +) -> dict[str, Any]: + layout = block.get("layout") or {} + table = block.get("table_structure") or {} + formula = block.get("formula") or {} + return { + "source_path": str(path), + "doc_id": parsed.doc_id, + "dataset_id": dataset_document.dataset_id, + "parser": parser_name, + "element_count": int(block.get("element_count", 0)), + "table_count": int(block.get("table_count", 0)), + "figure_count": int(block.get("figure_count", 0)), + "layout_evaluated": "layout" in block, + "table_evaluated": "table_structure" in block, + "formula_evaluated": "formula" in block, + "layout_prediction_count": int(layout.get("prediction_count", 0)), + "layout_class_aware_f1": float(layout.get("class_aware_f1", 0.0)), + "layout_class_aware_precision": float(layout.get("class_aware_precision", 0.0)), + "layout_class_aware_recall": float(layout.get("class_aware_recall", 0.0)), + "layout_class_agnostic_f1": float(layout.get("class_agnostic_f1", 0.0)), + "table_structure_score": float(table.get("mean_table_score", 0.0)), + "table_match_rate": float(table.get("table_match_rate", 0.0)), + "table_cell_content_f1": float(table.get("mean_cell_content_f1", 0.0)), + "formula_cer": float(formula.get("mean_cer", 0.0)) if formula else 0.0, + "formula_accuracy": float(formula.get("mean_accuracy", 0.0)) if formula else 0.0, + "formula_exact_match_rate": float(formula.get("exact_match_rate", 0.0)) if formula else 0.0, + } + + +def _write_per_parser_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "dataset_id", + "parser", + "element_count", + "table_count", + "figure_count", + "layout_evaluated", + "table_evaluated", + "formula_evaluated", + "layout_prediction_count", + "layout_class_aware_f1", + "layout_class_aware_precision", + "layout_class_aware_recall", + "layout_class_agnostic_f1", + "table_structure_score", + "table_match_rate", + "table_cell_content_f1", + "formula_cer", + "formula_accuracy", + "formula_exact_match_rate", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _retrieval_metrics(parsed, config: dict | None = None) -> dict: + from zsgdp.benchmarks.embedding_retriever import build_retriever + + retriever = build_retriever(config) if config else None + run = run_retrieval_for_document(parsed, retriever=retriever) + if not run["evaluated"]: + return { + "summary": { + "evaluated": False, + "query_count": 0, + "recall_at_1": 0.0, + "recall_at_5": 0.0, + "mean_reciprocal_rank": 0.0, + }, + "metrics": None, + "reason": run.get("reason"), + } + metrics = compute_retrieval_metrics( + ((result["retrieved"], result["truths"]) for result in run["results"]), + ) + return { + "summary": { + "evaluated": True, + "query_count": int(metrics["query_count"]), + "recall_at_1": float(metrics["recall_at_k"].get(1, 0.0)), + "recall_at_5": float(metrics["recall_at_k"].get(5, 0.0)), + "mean_reciprocal_rank": float(metrics["mean_reciprocal_rank"]), + }, + "metrics": metrics, + "reason": None, + } + + +def _retrieval_row(path: Path, parsed, retrieval_metrics: dict) -> dict: + detail = retrieval_metrics["metrics"] or {} + recall = detail.get("recall_at_k", {}) + return { + "source_path": str(path), + "doc_id": parsed.doc_id, + "query_count": int(detail.get("query_count", 0)), + "recall_at_1": float(recall.get(1, 0.0)), + "recall_at_3": float(recall.get(3, 0.0)), + "recall_at_5": float(recall.get(5, 0.0)), + "mean_reciprocal_rank": float(detail.get("mean_reciprocal_rank", 0.0)), + "citation_accuracy_at_5": float(detail.get("citation_accuracy_at_k", {}).get(5, 0.0)), + } + + +def _parser_contribution(parsed) -> dict[str, Any]: + """Count which parser produced each merged element/table/figure. + + Counts are over the *post-merge* output, not the pre-merge candidates. + This is a "contribution" view (which parser's output survived) rather + than an "ablation" view (which parser would do best alone). + """ + + counts: dict[str, int] = {} + for element in parsed.elements: + counts[element.source_parser] = counts.get(element.source_parser, 0) + 1 + for table in parsed.tables: + counts[table.source_parser] = counts.get(table.source_parser, 0) + 1 + for figure in parsed.figures: + counts[figure.source_parser] = counts.get(figure.source_parser, 0) + 1 + total = sum(counts.values()) + fractions = {parser: (count / total) for parser, count in counts.items()} if total else {} + return {"counts": counts, "fractions": fractions, "total": total} + + +def _aggregate_parser_contributions(documents: list[dict]) -> dict[str, Any]: + parser_totals: dict[str, int] = {} + grand_total = 0 + for doc in documents: + counts = doc.get("parser_contribution_counts") or {} + if not isinstance(counts, dict): + continue + for parser, count in counts.items(): + parser_totals[parser] = parser_totals.get(parser, 0) + int(count) + grand_total += int(count) + fractions = {parser: (count / grand_total) for parser, count in parser_totals.items()} if grand_total else {} + return {"counts": dict(sorted(parser_totals.items())), "fractions": dict(sorted(fractions.items())), "total": grand_total} + + +def _write_retrieval_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "query_count", + "recall_at_1", + "recall_at_3", + "recall_at_5", + "mean_reciprocal_rank", + "citation_accuracy_at_5", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _table_structure_metrics(dataset_document: DatasetDocument, parsed) -> dict: + adapter = TABLE_TRUTH_ADAPTERS.get(dataset_document.dataset_id) + if adapter is None or dataset_document.ground_truth is None: + return { + "summary": {"evaluated": False, "mean_table_score": 1.0, "table_match_rate": 1.0}, + "metrics": None, + "reason": "no_table_truth_adapter" if adapter is None else "no_ground_truth", + } + truths = adapter(dataset_document.ground_truth) + predictions = parsed_table_records(parsed) + if not truths and not predictions: + return { + "summary": {"evaluated": False, "mean_table_score": 1.0, "table_match_rate": 1.0}, + "metrics": None, + "reason": "no_truths_and_no_predictions", + } + metrics = compute_table_structure_score(predictions, truths) + return { + "summary": { + "evaluated": True, + "mean_table_score": float(metrics["mean_table_score"]), + "table_match_rate": float(metrics["table_match_rate"]), + }, + "metrics": metrics, + "reason": None, + } + + +def _formula_extraction_metrics(dataset_document: DatasetDocument, parsed) -> dict: + adapter = FORMULA_TRUTH_ADAPTERS.get(dataset_document.dataset_id) + if adapter is None or dataset_document.ground_truth is None: + return { + "summary": {"evaluated": False, "mean_cer": 0.0, "mean_accuracy": 1.0, "exact_match_rate": 1.0}, + "metrics": None, + "reason": "no_formula_truth_adapter" if adapter is None else "no_ground_truth", + } + truths = adapter(dataset_document.ground_truth) + predictions = parsed_formula_records(parsed) + if not truths and not predictions: + return { + "summary": {"evaluated": False, "mean_cer": 0.0, "mean_accuracy": 1.0, "exact_match_rate": 1.0}, + "metrics": None, + "reason": "no_truths_and_no_predictions", + } + metrics = compute_formula_extraction(predictions, truths) + return { + "summary": { + "evaluated": True, + "mean_cer": float(metrics["mean_cer"]), + "mean_accuracy": float(metrics["mean_accuracy"]), + "exact_match_rate": float(metrics["exact_match_rate"]), + }, + "metrics": metrics, + "reason": None, + } + + +def _table_structure_row(path: Path, parsed, dataset_document: DatasetDocument, metrics: dict) -> dict: + detail = metrics["metrics"] or {} + return { + "source_path": str(path), + "doc_id": parsed.doc_id, + "dataset_id": dataset_document.dataset_id, + "prediction_count": int(detail.get("prediction_count", 0)), + "truth_count": int(detail.get("truth_count", 0)), + "matched_pair_count": int(detail.get("matched_pair_count", 0)), + "table_match_rate": float(detail.get("table_match_rate", 0.0)), + "mean_table_score": float(detail.get("mean_table_score", 0.0)), + "mean_shape_similarity": float(detail.get("mean_shape_similarity", 0.0)), + "mean_cell_content_f1": float(detail.get("mean_cell_content_f1", 0.0)), + "table_count_delta": int(detail.get("table_count_delta", 0)), + } + + +def _formula_row(path: Path, parsed, dataset_document: DatasetDocument, metrics: dict) -> dict: + detail = metrics["metrics"] or {} + return { + "source_path": str(path), + "doc_id": parsed.doc_id, + "dataset_id": dataset_document.dataset_id, + "prediction_count": int(detail.get("prediction_count", 0)), + "truth_count": int(detail.get("truth_count", 0)), + "matched_pair_count": int(detail.get("matched_pair_count", 0)), + "mean_cer": float(detail.get("mean_cer", 1.0)), + "mean_accuracy": float(detail.get("mean_accuracy", 0.0)), + "exact_match_rate": float(detail.get("exact_match_rate", 0.0)), + "formula_precision": float(detail.get("formula_precision", 0.0)), + "formula_recall": float(detail.get("formula_recall", 0.0)), + } + + +def _write_table_structure_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "dataset_id", + "prediction_count", + "truth_count", + "matched_pair_count", + "table_match_rate", + "mean_table_score", + "mean_shape_similarity", + "mean_cell_content_f1", + "table_count_delta", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _write_formula_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "dataset_id", + "prediction_count", + "truth_count", + "matched_pair_count", + "mean_cer", + "mean_accuracy", + "exact_match_rate", + "formula_precision", + "formula_recall", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _layout_metrics(dataset_document: DatasetDocument, parsed) -> dict: + adapter = GROUND_TRUTH_ADAPTERS.get(dataset_document.dataset_id) + if adapter is None or dataset_document.ground_truth is None: + return { + "summary": {"evaluated": False, "class_aware_f1": 0.0, "class_agnostic_f1": 0.0}, + "metrics": None, + "reason": "no_ground_truth_adapter" if adapter is None else "no_ground_truth", + } + truths = adapter(dataset_document.ground_truth) + predictions = parsed_layout_predictions(parsed) + if not truths and not predictions: + return { + "summary": {"evaluated": False, "class_aware_f1": 0.0, "class_agnostic_f1": 0.0}, + "metrics": None, + "reason": "no_truths_and_no_predictions", + } + metrics = compute_layout_f1(predictions, truths) + return { + "summary": { + "evaluated": True, + "class_aware_f1": float(metrics["class_aware"]["f1"]), + "class_agnostic_f1": float(metrics["class_agnostic"]["f1"]), + }, + "metrics": metrics, + "reason": None, + } + + +def _layout_row(path: Path, parsed, dataset_document: DatasetDocument, layout_metrics: dict) -> dict: + metrics = layout_metrics["metrics"] or {} + class_aware = metrics.get("class_aware", {}) + class_agnostic = metrics.get("class_agnostic", {}) + return { + "source_path": str(path), + "doc_id": parsed.doc_id, + "dataset_id": dataset_document.dataset_id, + "iou_threshold": float(metrics.get("iou_threshold", 0.5)), + "prediction_count": int(metrics.get("prediction_count", 0)), + "truth_count": int(metrics.get("truth_count", 0)), + "class_aware_precision": float(class_aware.get("precision", 0.0)), + "class_aware_recall": float(class_aware.get("recall", 0.0)), + "class_aware_f1": float(class_aware.get("f1", 0.0)), + "class_agnostic_precision": float(class_agnostic.get("precision", 0.0)), + "class_agnostic_recall": float(class_agnostic.get("recall", 0.0)), + "class_agnostic_f1": float(class_agnostic.get("f1", 0.0)), + } + + +def _write_layout_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "dataset_id", + "iou_threshold", + "prediction_count", + "truth_count", + "class_aware_precision", + "class_aware_recall", + "class_aware_f1", + "class_agnostic_precision", + "class_agnostic_recall", + "class_agnostic_f1", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _repair_row(path: Path, parsed, repair_success: dict, disagreement: dict) -> dict: + return { + "source_path": str(path), + "doc_id": parsed.doc_id, + "candidate_count": int(disagreement.get("candidate_count", 0)), + "parser_disagreement_rate": float(disagreement.get("disagreement_rate", 0.0)), + "conflict_count": int(disagreement.get("conflict_count", 0)), + "iteration_count": int(repair_success.get("iteration_count", 0)), + "total_actions": int(repair_success.get("total_actions", 0)), + "score_delta": float(repair_success.get("score_delta", 0.0)), + "pre_repair_blocking_count": int(repair_success.get("pre_repair_blocking_count", 0)), + "post_repair_blocking_count": int(repair_success.get("post_repair_blocking_count", 0)), + "resolved_blocking_count": int(repair_success.get("resolved_blocking_count", 0)), + "regressed_blocking_count": int(repair_success.get("regressed_blocking_count", 0)), + "repair_resolution_rate": float(repair_success.get("repair_resolution_rate", 1.0)), + "repair_regression_rate": float(repair_success.get("repair_regression_rate", 0.0)), + } + + +def _write_repair_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "candidate_count", + "parser_disagreement_rate", + "conflict_count", + "iteration_count", + "total_actions", + "score_delta", + "pre_repair_blocking_count", + "post_repair_blocking_count", + "resolved_blocking_count", + "regressed_blocking_count", + "repair_resolution_rate", + "repair_regression_rate", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _chunk_strategy_counts(chunks: list) -> dict[str, int]: + counts: dict[str, int] = {} + for chunk in chunks: + counts[chunk.strategy] = counts.get(chunk.strategy, 0) + 1 + return counts + + +def _chunk_quality_metrics(metrics: dict) -> dict: + keys = [ + "chunk_count", + "parent_chunk_count", + "child_chunk_count", + "avg_chunk_tokens", + "max_chunk_tokens", + "table_chunk_coverage", + "figure_chunk_coverage", + ] + return {key: metrics[key] for key in keys if key in metrics} + + +def _chunk_strategy_rows(path: Path, parsed) -> list[dict]: + grouped: dict[str, list] = {} + for chunk in parsed.chunks: + grouped.setdefault(chunk.strategy, []).append(chunk) + + rows: list[dict] = [] + for strategy, chunks in sorted(grouped.items()): + token_counts = [chunk.token_count for chunk in chunks] + rows.append( + { + "source_path": str(path), + "doc_id": parsed.doc_id, + "strategy": strategy, + "quality_score": parsed.quality_report.score, + "chunk_count": len(chunks), + "avg_tokens": mean(token_counts) if token_counts else 0.0, + "max_tokens": max(token_counts) if token_counts else 0, + "table_linked_chunks": sum(1 for chunk in chunks if chunk.table_ids), + "figure_linked_chunks": sum(1 for chunk in chunks if chunk.figure_ids), + "visual_context_chunks": sum(1 for chunk in chunks if chunk.requires_visual_context), + } + ) + return rows + + +def _chunk_strategy_leaderboard(rows: list[dict]) -> list[dict]: + grouped: dict[str, list[dict]] = {} + for row in rows: + grouped.setdefault(row["strategy"], []).append(row) + + leaderboard: list[dict] = [] + for strategy, strategy_rows in grouped.items(): + leaderboard.append( + { + "strategy": strategy, + "runs": len(strategy_rows), + "total_chunks": sum(int(row.get("chunk_count", 0)) for row in strategy_rows), + "mean_chunk_count": _mean_value(strategy_rows, "chunk_count"), + "mean_avg_tokens": _mean_value(strategy_rows, "avg_tokens"), + "mean_max_tokens": _mean_value(strategy_rows, "max_tokens"), + "mean_quality_score": _mean_value(strategy_rows, "quality_score"), + "total_table_linked_chunks": sum(int(row.get("table_linked_chunks", 0)) for row in strategy_rows), + "total_figure_linked_chunks": sum(int(row.get("figure_linked_chunks", 0)) for row in strategy_rows), + "total_visual_context_chunks": sum(int(row.get("visual_context_chunks", 0)) for row in strategy_rows), + } + ) + return sorted(leaderboard, key=lambda row: (row["mean_quality_score"], row["total_chunks"]), reverse=True) + + +def _parser_leaderboard(rows: list[dict]) -> list[dict]: + grouped: dict[str, list[dict]] = {} + for row in rows: + grouped.setdefault(row["parser"], []).append(row) + + leaderboard: list[dict] = [] + for parser_name, parser_rows in grouped.items(): + successes = [row for row in parser_rows if not row.get("failed")] + leaderboard.append( + { + "parser": parser_name, + "runs": len(parser_rows), + "successes": len(successes), + "failures": len(parser_rows) - len(successes), + "mean_elapsed_seconds": _mean_value(successes, "elapsed_seconds"), + "mean_text_coverage_ratio": _mean_value(successes, "text_coverage_ratio"), + "mean_element_count": _mean_value(successes, "element_count"), + "mean_table_count": _mean_value(successes, "table_count"), + "mean_figure_count": _mean_value(successes, "figure_count"), + "mean_valid_table_ratio": _mean_value(successes, "valid_table_ratio"), + } + ) + return sorted( + leaderboard, + key=lambda row: (row["mean_text_coverage_ratio"], row["mean_valid_table_ratio"], -row["mean_elapsed_seconds"]), + reverse=True, + ) + + +def _mean_value(rows: list[dict], key: str) -> float: + values = [float(row[key]) for row in rows if row.get(key) is not None] + return mean(values) if values else 0.0 + + +def _write_leaderboard_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "parser", + "runs", + "successes", + "failures", + "mean_elapsed_seconds", + "mean_text_coverage_ratio", + "mean_element_count", + "mean_table_count", + "mean_figure_count", + "mean_valid_table_ratio", + ] + _write_csv(path, rows, fieldnames) + + +def _write_parser_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "parser", + "failed", + "error", + "elapsed_seconds", + "page_count", + "element_count", + "table_count", + "figure_count", + "text_chars", + "expected_text_chars", + "text_coverage_ratio", + "valid_table_ratio", + "has_bboxes", + "has_page_images", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _write_chunk_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "strategy", + "quality_score", + "chunk_count", + "avg_tokens", + "max_tokens", + "table_linked_chunks", + "figure_linked_chunks", + "visual_context_chunks", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _write_structure_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "table_count", + "valid_table_count", + "table_exact", + "figure_count", + "captioned_figure_count", + "figure_caption_correct", + "reading_order_issue_count", + "reading_order_health", + "document_text_coverage", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _write_chunk_quality_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "chunk_count", + "boundary_precision", + "parent_child_resolution", + "provenance_completeness", + "retrieval_readiness", + "table_chunk_coverage", + "figure_chunk_coverage", + "avg_tokens", + "max_tokens", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _write_throughput_rows_csv(path: Path, rows: list[dict]) -> None: + fieldnames = [ + "source_path", + "doc_id", + "page_count", + "elapsed_seconds", + "pages_per_second", + "elements_per_second", + "chunks_per_second", + "gpu_task_count", + "runtime_device", + "max_gpu_seconds_per_doc", + ] + normalized = [{field: row.get(field, "") for field in fieldnames} for row in rows] + _write_csv(path, normalized, fieldnames) + + +def _write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None: + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) diff --git a/zsgdp/benchmarks/per_parser_metrics.py b/zsgdp/benchmarks/per_parser_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..c0535e65dfe2e2c3aad801813c3bdb4c07744f52 --- /dev/null +++ b/zsgdp/benchmarks/per_parser_metrics.py @@ -0,0 +1,108 @@ +"""Per-parser GT-comparison metrics within a single merged run. + +Definitions (pinned): + +- This module reads parsed.provenance["candidates"] (a per-parser snapshot of + pre-merge elements/tables/figures, populated by the pipeline) and computes + layout F1 / table structure / formula CER for *each* parser against the + same ground truth that was used for the document-level metrics. +- Returned shape is one block per parser. A parser appears only if it + produced a candidate; failed parsers from parser_metrics are not included. +- The GT side is computed once and shared across parsers — so per-parser + comparisons are apples-to-apples on the same truth set. +- This is distinct from the ablation runner: ablation reruns the parse with + a different parser selection (changes routing, repair, merger inputs); + per-parser-from-merged keeps the routing constant and asks "which parser + contributed the more accurate elements to this merged run." +""" + +from __future__ import annotations + +from typing import Any + +from zsgdp.benchmarks.ground_truth import ( + formula_records_from_items, + layout_predictions_from_items, + table_records_from_items, +) +from zsgdp.verify.formula_extraction import compute_formula_extraction +from zsgdp.verify.layout_f1 import compute_layout_f1 +from zsgdp.verify.table_structure import compute_table_structure_score + + +def compute_per_parser_metrics( + parsed, + *, + layout_truths: list[dict[str, Any]] | None = None, + table_truths: list[dict[str, Any]] | None = None, + formula_truths: list[dict[str, Any]] | None = None, +) -> dict[str, dict[str, Any]]: + """Compute layout/table/formula metrics per parser candidate. + + All three truth lists are optional; whichever is non-None drives that + metric. When a truth list is None or empty, the corresponding metric + block is omitted from the per-parser record. + """ + + candidates = (parsed.provenance.get("candidates") or {}) if hasattr(parsed, "provenance") else {} + if not isinstance(candidates, dict): + return {} + + out: dict[str, dict[str, Any]] = {} + for parser_name, candidate in candidates.items(): + if not isinstance(candidate, dict): + continue + elements = candidate.get("elements") or [] + tables = candidate.get("tables") or [] + figures = candidate.get("figures") or [] + + block: dict[str, Any] = { + "parser": parser_name, + "element_count": len(elements), + "table_count": len(tables), + "figure_count": len(figures), + } + + if layout_truths: + predictions = layout_predictions_from_items( + elements=elements, + tables=tables, + figures=figures, + ) + metrics = compute_layout_f1(predictions, layout_truths) + block["layout"] = { + "prediction_count": metrics["prediction_count"], + "truth_count": metrics["truth_count"], + "class_aware_f1": metrics["class_aware"]["f1"], + "class_aware_precision": metrics["class_aware"]["precision"], + "class_aware_recall": metrics["class_aware"]["recall"], + "class_agnostic_f1": metrics["class_agnostic"]["f1"], + } + + if table_truths: + predictions = table_records_from_items(elements=elements, tables=tables) + metrics = compute_table_structure_score(predictions, table_truths) + block["table_structure"] = { + "prediction_count": metrics["prediction_count"], + "truth_count": metrics["truth_count"], + "matched_pair_count": metrics["matched_pair_count"], + "table_match_rate": metrics["table_match_rate"], + "mean_table_score": metrics["mean_table_score"], + "mean_shape_similarity": metrics["mean_shape_similarity"], + "mean_cell_content_f1": metrics["mean_cell_content_f1"], + } + + if formula_truths: + predictions = formula_records_from_items(elements=elements) + metrics = compute_formula_extraction(predictions, formula_truths) + block["formula"] = { + "prediction_count": metrics["prediction_count"], + "truth_count": metrics["truth_count"], + "matched_pair_count": metrics["matched_pair_count"], + "mean_cer": metrics["mean_cer"], + "mean_accuracy": metrics["mean_accuracy"], + "exact_match_rate": metrics["exact_match_rate"], + } + + out[parser_name] = block + return out diff --git a/zsgdp/benchmarks/retrieval.py b/zsgdp/benchmarks/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1475ff031607a2b67d621341f26c66be2a84a2 --- /dev/null +++ b/zsgdp/benchmarks/retrieval.py @@ -0,0 +1,223 @@ +"""Lexical retriever and synthetic QA generator for retrieval benchmarking. + +The retriever is intentionally model-free: TF-IDF cosine similarity over +chunk text. Real embedders can be plugged in by passing a custom retriever +to run_retrieval_benchmark — the metric module zsgdp.verify.retrieval is +agnostic to how the rankings were produced. + +QA generation is also model-free. We pick "distinctive sentences" — sentences +that are unique to a single chunk after lowercasing — as queries. The truth +chunk_id is the chunk that owns the sentence. Limitations are explicit: +verbatim retrieval is easier than real RAG-QA (which requires paraphrasing), +so the metrics reported here are a *lower-bound proxy* for retrieval quality. +A retriever that fails verbatim retrieval will fail real RAG; passing it is +necessary but not sufficient. +""" + +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass, field +import math +import re +from typing import Any, Iterable, Protocol, Sequence + +from zsgdp.schema import Chunk, ParsedDocument + + +@dataclass(slots=True) +class RetrievalQuery: + query_id: str + text: str + truths: list[str] + metadata: dict[str, Any] = field(default_factory=dict) + + +class Retriever(Protocol): + def index(self, chunks: Sequence[Chunk]) -> None: ... + def query(self, text: str, *, top_k: int) -> list[str]: ... + + +class LexicalRetriever: + """TF-IDF cosine similarity over chunk text.""" + + def __init__(self) -> None: + self._chunk_ids: list[str] = [] + self._vectors: list[dict[str, float]] = [] + self._norms: list[float] = [] + self._idf: dict[str, float] = {} + + def index(self, chunks: Sequence[Chunk]) -> None: + self._chunk_ids = [] + self._vectors = [] + self._norms = [] + token_lists: list[list[str]] = [] + document_frequency: Counter[str] = Counter() + for chunk in chunks: + tokens = _tokenize(chunk.text) + token_lists.append(tokens) + document_frequency.update(set(tokens)) + self._chunk_ids.append(chunk.chunk_id) + + chunk_count = max(1, len(token_lists)) + self._idf = { + token: math.log((1 + chunk_count) / (1 + frequency)) + 1.0 + for token, frequency in document_frequency.items() + } + for tokens in token_lists: + vector = self._tf_idf_vector(tokens) + self._vectors.append(vector) + self._norms.append(math.sqrt(sum(value * value for value in vector.values()))) + + def query(self, text: str, *, top_k: int) -> list[str]: + if not self._vectors: + return [] + tokens = _tokenize(text) + if not tokens: + return [] + query_vector = self._tf_idf_vector(tokens) + query_norm = math.sqrt(sum(value * value for value in query_vector.values())) + if query_norm == 0: + return [] + scored: list[tuple[float, int]] = [] + for index, doc_vector in enumerate(self._vectors): + doc_norm = self._norms[index] + if doc_norm == 0: + continue + dot = 0.0 + short, long = ( + (query_vector, doc_vector) if len(query_vector) < len(doc_vector) else (doc_vector, query_vector) + ) + for token, value in short.items(): + long_value = long.get(token) + if long_value is not None: + dot += value * long_value + if dot <= 0: + continue + score = dot / (query_norm * doc_norm) + scored.append((score, index)) + scored.sort(key=lambda item: (-item[0], item[1])) + return [self._chunk_ids[index] for _score, index in scored[:top_k]] + + def _tf_idf_vector(self, tokens: Sequence[str]) -> dict[str, float]: + if not tokens: + return {} + counts = Counter(tokens) + length = float(sum(counts.values())) + return {token: (count / length) * self._idf.get(token, 1.0) for token, count in counts.items()} + + +def synthesize_qa_set( + parsed: ParsedDocument, + *, + max_queries_per_doc: int = 20, + min_sentence_tokens: int = 6, +) -> list[RetrievalQuery]: + chunks_by_id = {chunk.chunk_id: chunk for chunk in parsed.chunks if chunk.text.strip()} + if not chunks_by_id: + return [] + + sentences_by_chunk: dict[str, list[str]] = {} + sentence_owners: dict[str, set[str]] = {} + for chunk_id, chunk in chunks_by_id.items(): + sentences = _split_sentences(chunk.text) + sentences_by_chunk[chunk_id] = sentences + for sentence in sentences: + normalized = _normalize_sentence(sentence) + if not normalized: + continue + if len(normalized.split()) < min_sentence_tokens: + continue + sentence_owners.setdefault(normalized, set()).add(chunk_id) + + queries: list[RetrievalQuery] = [] + for chunk_id, sentences in sentences_by_chunk.items(): + if len(queries) >= max_queries_per_doc: + break + for sentence in sentences: + normalized = _normalize_sentence(sentence) + if not normalized: + continue + owners = sentence_owners.get(normalized) or set() + if owners != {chunk_id}: + continue + queries.append( + RetrievalQuery( + query_id=f"{chunk_id}_q{len(queries) + 1}", + text=sentence.strip(), + truths=[chunk_id], + metadata={ + "source_chunk_id": chunk_id, + "page_start": chunks_by_id[chunk_id].page_start, + "page_end": chunks_by_id[chunk_id].page_end, + }, + ) + ) + break # one query per chunk + return queries + + +def run_retrieval_for_document( + parsed: ParsedDocument, + *, + retriever: Retriever | None = None, + queries: Sequence[RetrievalQuery] | None = None, + top_k: int = 5, + max_queries_per_doc: int = 20, +) -> dict[str, Any]: + chunks = [chunk for chunk in parsed.chunks if chunk.text.strip()] + if queries is None: + queries = synthesize_qa_set(parsed, max_queries_per_doc=max_queries_per_doc) + queries = list(queries) + if not chunks or not queries: + return { + "evaluated": False, + "query_count": 0, + "results": [], + "reason": "no_chunks" if not chunks else "no_queries", + } + retriever = retriever or LexicalRetriever() + retriever.index(chunks) + results: list[dict[str, Any]] = [] + for query in queries: + retrieved = retriever.query(query.text, top_k=top_k) + results.append( + { + "query_id": query.query_id, + "truths": list(query.truths), + "retrieved": retrieved, + "metadata": dict(query.metadata), + } + ) + return { + "evaluated": True, + "query_count": len(results), + "top_k": top_k, + "results": results, + "reason": None, + } + + +_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+") +_TOKEN_RE = re.compile(r"[a-zA-Z0-9]+") + + +def _split_sentences(text: str) -> list[str]: + parts: list[str] = [] + for line in text.splitlines(): + line = line.strip() + if not line: + continue + for sentence in _SENTENCE_BOUNDARY.split(line): + stripped = sentence.strip() + if stripped: + parts.append(stripped) + return parts + + +def _tokenize(text: str) -> list[str]: + return [token.lower() for token in _TOKEN_RE.findall(text or "")] + + +def _normalize_sentence(sentence: str) -> str: + return " ".join(_tokenize(sentence)) diff --git a/zsgdp/benchmarks/structure_quality.py b/zsgdp/benchmarks/structure_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..d07916f712a276bc01eb2f1eaaaa24d42dd4cf75 --- /dev/null +++ b/zsgdp/benchmarks/structure_quality.py @@ -0,0 +1,39 @@ +"""Document-structure benchmark helpers.""" + +from __future__ import annotations + +from zsgdp.schema import ParsedDocument +from zsgdp.verify.table_quality import markdown_table_is_valid + + +def score_structure_quality(results: list[dict]) -> dict[str, float]: + if not results: + return {"table_exactness": 0.0, "figure_caption_accuracy": 0.0, "reading_order_health": 0.0} + return { + "table_exactness": sum(float(item.get("table_exact", 0.0)) for item in results) / len(results), + "figure_caption_accuracy": sum(float(item.get("figure_caption_correct", 0.0)) for item in results) / len(results), + "reading_order_health": sum(float(item.get("reading_order_health", 0.0)) for item in results) / len(results), + } + + +def structure_quality_record(parsed: ParsedDocument, source_path: str) -> dict: + valid_tables = sum(1 for table in parsed.tables if markdown_table_is_valid(table.markdown)) + table_count = len(parsed.tables) + captioned_figures = sum(1 for figure in parsed.figures if figure.caption or figure.vlm_description) + figure_count = len(parsed.figures) + reading_order_issues = sum(1 for issue in parsed.quality_report.issues if issue.issue_type == "reading_order_failure") + page_count = len(parsed.pages) + + return { + "source_path": source_path, + "doc_id": parsed.doc_id, + "table_count": table_count, + "valid_table_count": valid_tables, + "table_exact": (valid_tables / table_count) if table_count else 1.0, + "figure_count": figure_count, + "captioned_figure_count": captioned_figures, + "figure_caption_correct": (captioned_figures / figure_count) if figure_count else 1.0, + "reading_order_issue_count": reading_order_issues, + "reading_order_health": 1.0 - min(reading_order_issues / page_count, 1.0) if page_count else 1.0, + "document_text_coverage": float(parsed.quality_report.metrics.get("document_text_coverage", 0.0)), + } diff --git a/zsgdp/benchmarks/throughput.py b/zsgdp/benchmarks/throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..60ce791b4619639a4384c2785f982a4291de4ffb --- /dev/null +++ b/zsgdp/benchmarks/throughput.py @@ -0,0 +1,54 @@ +"""Throughput benchmark helpers.""" + +from __future__ import annotations + + +def pages_per_second(page_count: int, elapsed_seconds: float) -> float: + if elapsed_seconds <= 0: + return 0.0 + return page_count / elapsed_seconds + + +def throughput_record(parsed, source_path: str, elapsed_seconds: float) -> dict: + page_count = max(len(parsed.pages), 1) + return { + "source_path": source_path, + "doc_id": parsed.doc_id, + "page_count": page_count, + "elapsed_seconds": elapsed_seconds, + "pages_per_second": pages_per_second(page_count, elapsed_seconds), + "elements_per_second": _per_second(len(parsed.elements), elapsed_seconds), + "chunks_per_second": _per_second(len(parsed.chunks), elapsed_seconds), + "gpu_task_count": len(parsed.provenance.get("gpu_tasks", [])), + "runtime_device": parsed.provenance.get("gpu_runtime", {}).get("device"), + "max_gpu_seconds_per_doc": parsed.provenance.get("gpu_runtime", {}).get("max_gpu_seconds_per_doc"), + } + + +def summarize_throughput(rows: list[dict]) -> dict[str, float]: + if not rows: + return { + "document_count": 0, + "total_pages": 0, + "total_elapsed_seconds": 0.0, + "mean_pages_per_second": 0.0, + "mean_chunks_per_second": 0.0, + } + return { + "document_count": len(rows), + "total_pages": sum(int(row.get("page_count", 0)) for row in rows), + "total_elapsed_seconds": sum(float(row.get("elapsed_seconds", 0.0)) for row in rows), + "mean_pages_per_second": _mean(rows, "pages_per_second"), + "mean_chunks_per_second": _mean(rows, "chunks_per_second"), + } + + +def _per_second(count: int, elapsed_seconds: float) -> float: + if elapsed_seconds <= 0: + return 0.0 + return count / elapsed_seconds + + +def _mean(rows: list[dict], key: str) -> float: + values = [float(row.get(key, 0.0)) for row in rows] + return sum(values) / len(values) if values else 0.0 diff --git a/zsgdp/chunking/__init__.py b/zsgdp/chunking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..518cfe432dc5920099f469c09e04c01dcf0fc870 --- /dev/null +++ b/zsgdp/chunking/__init__.py @@ -0,0 +1,4 @@ +from zsgdp.chunking.chunker import build_agentic_chunks +from zsgdp.chunking.planner import ChunkingPlan, plan_chunking + +__all__ = ["ChunkingPlan", "build_agentic_chunks", "plan_chunking"] diff --git a/zsgdp/chunking/agentic.py b/zsgdp/chunking/agentic.py new file mode 100644 index 0000000000000000000000000000000000000000..5abbe24827aa6a70868f3d947974474232720cc7 --- /dev/null +++ b/zsgdp/chunking/agentic.py @@ -0,0 +1,55 @@ +"""Agentic/proposition chunking hooks.""" + +from __future__ import annotations + +import re + + +def proposition_prompt(section_text: str) -> str: + return ( + "Decompose this document section into faithful atomic propositions. " + "Preserve citations, units, names, dates, and constraints. Do not invent facts.\n\n" + f"{section_text}" + ) + + +def agentic_chunking_recommended(profile_metadata: dict, quality_metrics: dict) -> bool: + high_complexity = bool(profile_metadata.get("visual_complexity_high") or profile_metadata.get("high_value_low_volume")) + weak_quality = float(quality_metrics.get("document_text_coverage", 1.0)) < 0.75 + return high_complexity or weak_quality + + +def extract_atomic_propositions(text: str, *, max_propositions: int = 8) -> list[str]: + """Deterministically approximate proposition chunks from sentence units.""" + + normalized = text.replace("\r", "\n") + candidates: list[str] = [] + for line in normalized.splitlines(): + stripped = re.sub(r"^\s*[-*0-9.)]+\s*", "", line).strip() + if stripped: + candidates.extend(_split_claims(stripped)) + if not candidates: + candidates.extend(_split_claims(normalized)) + + propositions: list[str] = [] + seen: set[str] = set() + for candidate in candidates: + clean = re.sub(r"\s+", " ", candidate).strip(" -") + if len(clean.split()) < 4: + continue + key = clean.lower() + if key in seen: + continue + seen.add(key) + propositions.append(clean) + if len(propositions) >= max_propositions: + break + return propositions + + +def _split_claims(text: str) -> list[str]: + return [ + part.strip() + for part in re.split(r"(?<=[.!?])\s+|;\s+", text) + if part.strip() + ] diff --git a/zsgdp/chunking/chunker.py b/zsgdp/chunking/chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..e124e614f049c1fe788f64888394f205940af12f --- /dev/null +++ b/zsgdp/chunking/chunker.py @@ -0,0 +1,735 @@ +"""Agentic chunk construction. + +The implementation keeps the high-ROI baseline live in code: +structure-aware recursive chunks around 512 tokens with overlap, plus +parent/child hierarchy. Costlier methods such as semantic, late, contextual, +vision-guided, and proposition chunking emit deterministic local candidates +with metadata showing where embedding, VLM, or LLM backends can take over. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from zsgdp.chunking.agentic import extract_atomic_propositions +from zsgdp.chunking.contextual import contextualized_text, deterministic_context_prefix +from zsgdp.chunking.hierarchy import update_section_path +from zsgdp.chunking.late import describe_late_chunking_requirement, late_chunking_available, late_chunking_metadata +from zsgdp.chunking.planner import ChunkingPlan, plan_chunking +from zsgdp.chunking.semantic import semantic_groups +from zsgdp.chunking.splitters import recursive_split, token_count +from zsgdp.chunking.vision_guided import ( + candidate_visual_regions, + figure_region_text, + mark_visual_context_required, + table_region_text, +) +from zsgdp.schema import Chunk, DocumentProfile, Element, ParsedDocument + + +@dataclass(slots=True) +class ChunkSource: + text: str + page_start: int + page_end: int + section_path: list[str] + element_ids: list[str] + source_parser: str + content_type: str = "prose" + metadata: dict = field(default_factory=dict) + + +def build_agentic_chunks(parsed: ParsedDocument, profile: DocumentProfile, config: dict | None = None) -> list[Chunk]: + config = config or {} + chunking = config.get("chunking", {}) + if not chunking.get("enabled", True): + parsed.provenance["chunking"] = {"enabled": False} + return [] + + plan = plan_chunking(parsed, profile, config) + chunks: list[Chunk] = [] + parent_chunks: list[Chunk] = [] + vision_regions: list[dict] = [] + parent_sources = _build_parent_sources(parsed, plan) + + if "fixed_token_baseline" in plan.strategies: + chunks.extend(_build_fixed_token_baseline_chunks(parsed, plan, start_index=len(chunks) + 1)) + + for source in parent_sources: + parent = _make_chunk( + parsed=parsed, + index=len(chunks) + 1, + text=source.text, + page_start=source.page_start, + page_end=source.page_end, + section_path=source.section_path, + element_ids=source.element_ids, + content_type="parent", + strategy="parent_child", + boundary_reason=source.metadata.get("boundary_reason", "section_or_page"), + source_parser=source.source_parser, + metadata={"role": "parent", **source.metadata}, + ) + chunks.append(parent) + parent_chunks.append(parent) + + child_texts = recursive_split(source.text, plan.target_tokens, plan.overlap_tokens) + for child_text in child_texts: + child = _make_chunk( + parsed=parsed, + index=len(chunks) + 1, + text=child_text, + page_start=source.page_start, + page_end=source.page_end, + section_path=source.section_path, + element_ids=source.element_ids, + content_type=source.content_type, + strategy="recursive_structure", + boundary_reason="recursive_separator", + source_parser=source.source_parser, + parent_chunk_id=parent.chunk_id, + context_prefix=_context_prefix(source.section_path, source.page_start, source.page_end) + if "contextual_retrieval" in plan.strategies + else None, + metadata={"role": "child", "parent_strategy": parent.strategy}, + ) + chunks.append(child) + parent.child_chunk_ids.append(child.chunk_id) + + if "page_level" in plan.strategies: + chunks.extend(_build_page_chunks(parsed, start_index=len(chunks) + 1)) + + if chunking.get("table_chunks", True): + for table in parsed.tables: + text = table.natural_language_rendering or table.markdown or table.html or table.caption or "" + if not text.strip(): + continue + chunks.append( + _make_chunk( + parsed=parsed, + index=len(chunks) + 1, + text=text, + page_start=min(table.page_nums or [1]), + page_end=max(table.page_nums or [1]), + section_path=[], + element_ids=[], + table_ids=[table.table_id], + content_type="table", + strategy="table_object", + boundary_reason="table_object", + source_parser=table.source_parser, + quality_score=table.confidence, + metadata={ + "markdown": table.markdown, + "html": table.html, + "natural_language_rendering": table.natural_language_rendering, + "bbox": table.bbox, + "crop_path": table.provenance.get("crop_path"), + "source_parsers": table.provenance.get("source_parsers", [table.source_parser]), + }, + ) + ) + + if chunking.get("figure_chunks", True): + for figure in parsed.figures: + text = figure.vlm_description or figure.caption or f"Figure {figure.figure_id} on page {figure.page_num}." + chunks.append( + _make_chunk( + parsed=parsed, + index=len(chunks) + 1, + text=text, + page_start=figure.page_num, + page_end=figure.page_num, + section_path=[], + element_ids=[], + figure_ids=[figure.figure_id], + content_type="figure", + strategy="figure_object", + boundary_reason="figure_object", + source_parser=figure.source_parser, + quality_score=figure.confidence, + requires_visual_context=True, + metadata={ + "image_path": figure.image_path, + "bbox": figure.bbox, + "chart_data": figure.chart_data, + }, + ) + ) + + if "semantic_chunking" in plan.strategies: + chunks.extend(_build_semantic_chunks(parsed, parent_sources, plan, config, start_index=len(chunks) + 1)) + + if "late_chunking" in plan.strategies: + chunks.extend(_build_late_chunks(parsed, parent_chunks, plan, config, start_index=len(chunks) + 1)) + + if "agentic_proposition_chunking" in plan.strategies: + chunks.extend(_build_proposition_chunks(parsed, parent_sources, config, start_index=len(chunks) + 1)) + + if "vision_guided" in plan.strategies: + mark_visual_context_required(chunks) + vision_regions = candidate_visual_regions(parsed) + chunks.extend(_build_vision_guided_chunks(parsed, start_index=len(chunks) + 1)) + + if "contextual_retrieval" in plan.strategies: + chunks.extend(_build_contextual_retrieval_chunks(parsed, chunks, start_index=len(chunks) + 1)) + + parsed.provenance["chunking"] = { + "enabled": True, + "plan": plan.to_dict(), + "chunk_count": len(chunks), + "parent_chunk_count": len(parent_chunks), + "fixed_token_baseline_count": sum(1 for chunk in chunks if chunk.strategy == "fixed_token_baseline"), + "semantic_chunk_count": sum(1 for chunk in chunks if chunk.strategy == "semantic"), + "late_chunk_count": sum(1 for chunk in chunks if chunk.strategy == "late"), + "contextual_retrieval_chunk_count": sum(1 for chunk in chunks if chunk.strategy == "contextual_retrieval"), + "vision_guided_chunk_count": sum(1 for chunk in chunks if chunk.strategy == "vision_guided"), + "agentic_proposition_chunk_count": sum(1 for chunk in chunks if chunk.strategy == "agentic_proposition"), + "vision_regions": vision_regions, + } + return chunks + + +def _build_fixed_token_baseline_chunks(parsed: ParsedDocument, plan: ChunkingPlan, start_index: int) -> list[Chunk]: + token_records: list[tuple[str, int, str, str]] = [] + elements = sorted(parsed.elements, key=lambda item: (item.page_num, item.reading_order or 0, item.element_id)) + for element in elements: + content = element.content().strip() + if not content: + continue + for token in content.split(): + token_records.append((token, element.page_num, element.element_id, element.source_parser)) + + if not token_records: + return [] + + chunks: list[Chunk] = [] + step = max(plan.target_tokens - plan.overlap_tokens, 1) + start = 0 + while start < len(token_records): + end = min(start + plan.target_tokens, len(token_records)) + window = token_records[start:end] + pages = [record[1] for record in window] + element_ids = _dedupe_in_order([record[2] for record in window]) + source_parser = _source_parser_for_window(window) + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=" ".join(record[0] for record in window), + page_start=min(pages), + page_end=max(pages), + section_path=[], + element_ids=element_ids, + content_type="baseline", + strategy="fixed_token_baseline", + boundary_reason="fixed_token_window", + source_parser=source_parser, + metadata={ + "role": "fixed_token_baseline", + "target_tokens": plan.target_tokens, + "overlap_tokens": plan.overlap_tokens, + "token_start": start, + "token_end": end, + }, + ) + ) + if end == len(token_records): + break + start += step + return chunks + + +def _build_parent_sources(parsed: ParsedDocument, plan: ChunkingPlan) -> list[ChunkSource]: + sources: list[ChunkSource] = [] + section_path: list[str] = [] + buffer: list[str] = [] + element_ids: list[str] = [] + page_start: int | None = None + page_end: int | None = None + source_parser = "unknown" + + def flush(reason: str) -> None: + nonlocal buffer, element_ids, page_start, page_end, source_parser + if not buffer or page_start is None or page_end is None: + return + text = "\n\n".join(buffer).strip() + if text: + sources.append( + ChunkSource( + text=text, + page_start=page_start, + page_end=page_end, + section_path=list(section_path), + element_ids=list(element_ids), + source_parser=source_parser, + metadata={"boundary_reason": reason}, + ) + ) + buffer = [] + element_ids = [] + page_start = None + page_end = None + source_parser = "unknown" + + elements = sorted(parsed.elements, key=lambda item: (item.page_num, item.reading_order or 0)) + for element in elements: + content = element.content().strip() + if not content: + continue + next_section_path = update_section_path(section_path, element) + starts_new_section = next_section_path != section_path and bool(buffer) + if starts_new_section: + flush("heading_boundary") + section_path = next_section_path + + if page_start is None: + page_start = element.page_num + page_end = element.page_num + source_parser = element.source_parser + + candidate = "\n\n".join(buffer + [content]) + if token_count(candidate) > plan.parent_target_tokens and buffer: + flush("parent_token_budget") + page_start = element.page_num + page_end = element.page_num + source_parser = element.source_parser + + buffer.append(content) + element_ids.append(element.element_id) + + flush("document_end") + return sources + + +def _build_page_chunks(parsed: ParsedDocument, start_index: int) -> list[Chunk]: + chunks: list[Chunk] = [] + by_page: dict[int, list[Element]] = {} + for element in parsed.elements: + by_page.setdefault(element.page_num, []).append(element) + + for page_num in sorted(by_page): + elements = sorted(by_page[page_num], key=lambda item: item.reading_order or 0) + text = "\n\n".join(element.content().strip() for element in elements if element.content().strip()) + if not text: + continue + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=text, + page_start=page_num, + page_end=page_num, + section_path=[], + element_ids=[element.element_id for element in elements], + content_type="page", + strategy="page_level", + boundary_reason="page_boundary", + source_parser=elements[0].source_parser if elements else "unknown", + metadata={"role": "page_level_baseline"}, + ) + ) + return chunks + + +def _build_semantic_chunks( + parsed: ParsedDocument, + sources: list[ChunkSource], + plan: ChunkingPlan, + config: dict, + start_index: int, +) -> list[Chunk]: + chunking = config.get("chunking", {}) + threshold = float(chunking.get("semantic_similarity_threshold", 0.18)) + chunks: list[Chunk] = [] + for source in sources: + groups, scores, boundaries = semantic_groups(source.text, threshold=threshold, max_tokens=plan.target_tokens) + for group_index, (sentences, sentence_start, sentence_end) in enumerate(groups): + text = " ".join(sentences).strip() + if not text: + continue + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=text, + page_start=source.page_start, + page_end=source.page_end, + section_path=source.section_path, + element_ids=source.element_ids, + content_type=source.content_type, + strategy="semantic", + boundary_reason="semantic_shift", + source_parser=source.source_parser, + context_prefix=_context_prefix(source.section_path, source.page_start, source.page_end), + metadata={ + "role": "semantic_candidate", + "execution_mode": "lexical_similarity_proxy", + "similarity_threshold": threshold, + "adjacent_similarity_scores": [round(score, 4) for score in scores], + "boundary_indexes": boundaries, + "group_index": group_index, + "sentence_start": sentence_start, + "sentence_end": sentence_end, + "source_boundary_reason": source.metadata.get("boundary_reason"), + }, + ) + ) + return chunks + + +def _build_late_chunks( + parsed: ParsedDocument, + parent_chunks: list[Chunk], + plan: ChunkingPlan, + config: dict, + start_index: int, +) -> list[Chunk]: + chunks: list[Chunk] = [] + backend_metadata = late_chunking_metadata(config) + execution_mode = "planned_token_pooling" if late_chunking_available(config) else "requires_long_context_embedding_backend" + for parent in parent_chunks: + spans = recursive_split(parent.text, plan.target_tokens, plan.overlap_tokens) + for span_index, span in enumerate(spans): + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=span, + page_start=parent.page_start, + page_end=parent.page_end, + section_path=parent.section_path, + element_ids=parent.element_ids, + content_type="late_span", + strategy="late", + boundary_reason="late_span_pooling", + source_parser=parent.source_parser, + parent_chunk_id=parent.chunk_id, + context_prefix=parent.context_prefix, + metadata={ + "role": "late_chunk_span", + "execution_mode": execution_mode, + "source_parent_chunk_id": parent.chunk_id, + "span_index": span_index, + "requirement": describe_late_chunking_requirement(), + **backend_metadata, + }, + ) + ) + return chunks + + +def _build_proposition_chunks( + parsed: ParsedDocument, + sources: list[ChunkSource], + config: dict, + start_index: int, +) -> list[Chunk]: + chunking = config.get("chunking", {}) + max_per_source = int(chunking.get("max_propositions_per_source", 8)) + max_total = int(chunking.get("max_proposition_chunks", 64)) + chunks: list[Chunk] = [] + + for source in sources: + for proposition_index, proposition in enumerate( + extract_atomic_propositions(source.text, max_propositions=max_per_source) + ): + if len(chunks) >= max_total: + return chunks + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=proposition, + page_start=source.page_start, + page_end=source.page_end, + section_path=source.section_path, + element_ids=source.element_ids, + content_type="proposition", + strategy="agentic_proposition", + boundary_reason="agentic_proposition", + source_parser=source.source_parser, + context_prefix=_context_prefix(source.section_path, source.page_start, source.page_end), + metadata={ + "role": "atomic_proposition", + "execution_mode": "deterministic_sentence_proxy", + "proposition_index": proposition_index, + "source_boundary_reason": source.metadata.get("boundary_reason"), + "source_element_ids": source.element_ids, + "llm_prompt_available": True, + }, + ) + ) + + for table in parsed.tables: + if len(chunks) >= max_total: + return chunks + table_text = table.natural_language_rendering or table.caption or table.markdown or "" + for proposition_index, proposition in enumerate( + extract_atomic_propositions(table_text, max_propositions=max_per_source) + ): + if len(chunks) >= max_total: + return chunks + page_nums = table.page_nums or [1] + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=proposition, + page_start=min(page_nums), + page_end=max(page_nums), + section_path=[], + element_ids=[], + table_ids=[table.table_id], + content_type="proposition", + strategy="agentic_proposition", + boundary_reason="agentic_table_proposition", + source_parser=_object_source_parser(table.source_parser, _fallback_source_parser(parsed)), + context_prefix=_context_prefix([], min(page_nums), max(page_nums)), + metadata={ + "role": "atomic_table_proposition", + "execution_mode": "deterministic_sentence_proxy", + "proposition_index": proposition_index, + "source_table_id": table.table_id, + "llm_prompt_available": True, + }, + ) + ) + + for figure in parsed.figures: + if len(chunks) >= max_total: + return chunks + figure_text = figure.vlm_description or figure.caption or "" + for proposition_index, proposition in enumerate( + extract_atomic_propositions(figure_text, max_propositions=max_per_source) + ): + if len(chunks) >= max_total: + return chunks + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=proposition, + page_start=figure.page_num, + page_end=figure.page_num, + section_path=[], + element_ids=[], + figure_ids=[figure.figure_id], + content_type="proposition", + strategy="agentic_proposition", + boundary_reason="agentic_figure_proposition", + source_parser=_object_source_parser(figure.source_parser, _fallback_source_parser(parsed)), + context_prefix=_context_prefix([], figure.page_num, figure.page_num), + metadata={ + "role": "atomic_figure_proposition", + "execution_mode": "deterministic_sentence_proxy", + "proposition_index": proposition_index, + "source_figure_id": figure.figure_id, + "llm_prompt_available": True, + }, + ) + ) + return chunks + + +def _build_vision_guided_chunks(parsed: ParsedDocument, start_index: int) -> list[Chunk]: + chunks: list[Chunk] = [] + fallback_parser = _fallback_source_parser(parsed) + for table in parsed.tables: + page_nums = table.page_nums or [1] + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=table_region_text(table), + page_start=min(page_nums), + page_end=max(page_nums), + section_path=[], + element_ids=[], + table_ids=[table.table_id], + content_type="visual_region", + strategy="vision_guided", + boundary_reason="visual_region", + source_parser=_object_source_parser(table.source_parser, fallback_parser), + requires_visual_context=True, + metadata={ + "role": "vision_guided_region", + "execution_mode": "layout_region_proxy", + "region_id": table.table_id, + "region_type": "table", + "bbox": table.bbox, + "crop_path": table.provenance.get("crop_path"), + "source_strategy": "table_object", + }, + ) + ) + for figure in parsed.figures: + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=figure_region_text(figure), + page_start=figure.page_num, + page_end=figure.page_num, + section_path=[], + element_ids=[], + figure_ids=[figure.figure_id], + content_type="visual_region", + strategy="vision_guided", + boundary_reason="visual_region", + source_parser=_object_source_parser(figure.source_parser, fallback_parser), + requires_visual_context=True, + metadata={ + "role": "vision_guided_region", + "execution_mode": "layout_region_proxy", + "region_id": figure.figure_id, + "region_type": "figure", + "bbox": figure.bbox, + "image_path": figure.image_path, + "chart_data": figure.chart_data, + "source_strategy": "figure_object", + }, + ) + ) + return chunks + + +def _build_contextual_retrieval_chunks( + parsed: ParsedDocument, + source_chunks: list[Chunk], + start_index: int, +) -> list[Chunk]: + chunks: list[Chunk] = [] + contextual_source_strategies = { + "recursive_structure", + "page_level", + "table_object", + "figure_object", + "semantic", + "vision_guided", + "agentic_proposition", + } + for source in list(source_chunks): + if source.strategy not in contextual_source_strategies or not source.text.strip(): + continue + prefix = source.context_prefix or deterministic_context_prefix(source) + source.context_prefix = prefix + source.metadata.setdefault("contextual_prefix_mode", "deterministic") + chunks.append( + _make_chunk( + parsed=parsed, + index=start_index + len(chunks), + text=contextualized_text(source), + page_start=source.page_start, + page_end=source.page_end, + section_path=source.section_path, + element_ids=source.element_ids, + table_ids=source.table_ids, + figure_ids=source.figure_ids, + content_type=source.content_type, + strategy="contextual_retrieval", + boundary_reason="contextual_prefix", + source_parser=source.source_parser, + parent_chunk_id=source.parent_chunk_id, + context_prefix=prefix, + quality_score=source.quality_score, + requires_visual_context=source.requires_visual_context, + metadata={ + "role": "contextual_retrieval_candidate", + "execution_mode": "deterministic_context_prefix", + "source_chunk_id": source.chunk_id, + "source_strategy": source.strategy, + "context_generation_mode": "deterministic", + }, + ) + ) + return chunks + + +def _make_chunk( + *, + parsed: ParsedDocument, + index: int, + text: str, + page_start: int, + page_end: int, + section_path: list[str], + element_ids: list[str], + content_type: str, + strategy: str, + boundary_reason: str, + source_parser: str, + table_ids: list[str] | None = None, + figure_ids: list[str] | None = None, + parent_chunk_id: str | None = None, + context_prefix: str | None = None, + quality_score: float | None = None, + requires_visual_context: bool = False, + metadata: dict | None = None, +) -> Chunk: + clean_text = text.strip() + return Chunk( + chunk_id=f"c{index}", + doc_id=parsed.doc_id, + page_start=page_start, + page_end=page_end, + section_path=section_path, + content_type=content_type, + text=clean_text, + element_ids=element_ids, + table_ids=table_ids or [], + figure_ids=figure_ids or [], + parent_chunk_id=parent_chunk_id, + strategy=strategy, + boundary_reason=boundary_reason, + token_count=token_count(clean_text), + source_parser=source_parser, + quality_score=parsed.quality_report.score if quality_score is None else quality_score, + requires_visual_context=requires_visual_context, + context_prefix=context_prefix, + metadata=metadata or {}, + ) + + +def _context_prefix(section_path: list[str], page_start: int, page_end: int) -> str: + section = " > ".join(section_path) if section_path else "Document" + page_text = f"page {page_start}" if page_start == page_end else f"pages {page_start}-{page_end}" + return f"This chunk is from {section}, {page_text}." + + +def _dedupe_in_order(items: list[str]) -> list[str]: + seen: set[str] = set() + out: list[str] = [] + for item in items: + if item not in seen: + out.append(item) + seen.add(item) + return out + + +def _source_parser_for_window(window: list[tuple[str, int, str, str]]) -> str: + parsers = _dedupe_in_order([record[3] for record in window if record[3]]) + if not parsers: + return "unknown" + if len(parsers) == 1: + return parsers[0] + return "mixed" + + +def _fallback_source_parser(parsed: ParsedDocument) -> str: + parsers = _dedupe_in_order( + [ + element.source_parser + for element in parsed.elements + if element.source_parser and element.source_parser != "unknown" + ] + ) + if not parsers: + return "unknown" + if len(parsers) == 1: + return parsers[0] + return "mixed" + + +def _object_source_parser(source_parser: str, fallback: str) -> str: + return fallback if not source_parser or source_parser == "unknown" else source_parser diff --git a/zsgdp/chunking/contextual.py b/zsgdp/chunking/contextual.py new file mode 100644 index 0000000000000000000000000000000000000000..2849fa8752d89e5532f0bfa03f66a8928b8ff2e1 --- /dev/null +++ b/zsgdp/chunking/contextual.py @@ -0,0 +1,28 @@ +"""Contextual retrieval chunking hooks.""" + +from __future__ import annotations + +from zsgdp.schema import Chunk + + +def deterministic_context_prefix(chunk: Chunk) -> str: + section = " > ".join(chunk.section_path) if chunk.section_path else "Document" + page_text = f"page {chunk.page_start}" if chunk.page_start == chunk.page_end else f"pages {chunk.page_start}-{chunk.page_end}" + content = chunk.content_type.replace("_", " ") + return f"This {content} chunk is from {section}, {page_text}." + + +def add_context_prefixes(chunks: list[Chunk]) -> list[Chunk]: + """Add deterministic context prefixes when an LLM backend is not configured.""" + + for chunk in chunks: + if chunk.context_prefix: + continue + chunk.context_prefix = deterministic_context_prefix(chunk) + chunk.metadata["contextual_prefix_mode"] = "deterministic" + return chunks + + +def contextualized_text(chunk: Chunk) -> str: + prefix = chunk.context_prefix or deterministic_context_prefix(chunk) + return f"{prefix}\n\n{chunk.text.strip()}".strip() diff --git a/zsgdp/chunking/figure_chunks.py b/zsgdp/chunking/figure_chunks.py new file mode 100644 index 0000000000000000000000000000000000000000..2edf49343c06c9b8cd807541b85634c7d0e0d103 --- /dev/null +++ b/zsgdp/chunking/figure_chunks.py @@ -0,0 +1,9 @@ +"""Figure chunk helpers.""" + +from __future__ import annotations + +from zsgdp.schema import FigureObject + + +def figure_chunk_text(figure: FigureObject) -> str: + return (figure.vlm_description or figure.caption or "").strip() diff --git a/zsgdp/chunking/hierarchy.py b/zsgdp/chunking/hierarchy.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa886b6934814ac3594de522332c940101f9ee4 --- /dev/null +++ b/zsgdp/chunking/hierarchy.py @@ -0,0 +1,17 @@ +"""Section hierarchy helpers for chunking.""" + +from __future__ import annotations + +from zsgdp.schema import Element + + +def update_section_path(section_path: list[str], element: Element) -> list[str]: + content = element.content().strip() + if not content: + return section_path + if element.type == "title": + return [content.lstrip("# ").strip()] + if element.type == "heading": + normalized = content.lstrip("# ").strip() + return (section_path[:1] or []) + [normalized] + return section_path diff --git a/zsgdp/chunking/late.py b/zsgdp/chunking/late.py new file mode 100644 index 0000000000000000000000000000000000000000..0e91db315ef0673166eb137f9838da05d73f9774 --- /dev/null +++ b/zsgdp/chunking/late.py @@ -0,0 +1,32 @@ +"""Late chunking hooks.""" + +from __future__ import annotations + + +def late_chunking_available(config: dict) -> bool: + embedding = _embedding_config(config) + if bool(embedding.get("supports_late_chunking", False)): + return True + model_id = str(embedding.get("model_id", "")).lower() + return "jina-embeddings-v3" in model_id or "jina-embeddings-v4" in model_id + + +def describe_late_chunking_requirement() -> str: + return "Late chunking requires a long-context embedding model that exposes token-level representations." + + +def late_chunking_metadata(config: dict) -> dict: + embedding = _embedding_config(config) + return { + "embedding_model": embedding.get("model_id"), + "embedding_task": embedding.get("task"), + "supports_late_chunking": late_chunking_available(config), + "requires_token_level_embeddings": True, + } + + +def _embedding_config(config: dict) -> dict: + embedding = config.get("embedding", {}) + if embedding: + return embedding + return config.get("gpu", {}).get("models", {}).get("embedding", {}) diff --git a/zsgdp/chunking/planner.py b/zsgdp/chunking/planner.py new file mode 100644 index 0000000000000000000000000000000000000000..8843b71e5b40612d944bf35d72b2cd419c0aeb28 --- /dev/null +++ b/zsgdp/chunking/planner.py @@ -0,0 +1,98 @@ +"""Agentic chunking planner.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from zsgdp.schema import DocumentProfile, ParsedDocument +from zsgdp.utils import to_plain_data + + +@dataclass(slots=True) +class ChunkingPlan: + strategies: list[str] + target_tokens: int + overlap_tokens: int + parent_target_tokens: int + reasons: list[str] = field(default_factory=list) + deferred_strategies: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +def plan_chunking(parsed: ParsedDocument, profile: DocumentProfile, config: dict[str, Any]) -> ChunkingPlan: + chunking = config.get("chunking", {}) + target_tokens = int(chunking.get("target_tokens", 512)) + overlap_tokens = max(0, int(target_tokens * float(chunking.get("overlap_ratio", 0.15)))) + parent_target_tokens = int(chunking.get("parent_target_tokens", 1600)) + strategies: list[str] = [] + reasons: list[str] = [] + deferred: list[str] = [] + + strategies.append("fixed_token_baseline") + reasons.append("Keep a simple fixed/token baseline for measurement.") + + baseline = str(chunking.get("baseline_strategy", "recursive_structure")) + if baseline not in strategies: + strategies.append(baseline) + reasons.append("Recursive structure-aware chunking is the default production baseline.") + + if _has_structure(parsed): + if "structure_aware" not in strategies: + strategies.append("structure_aware") + reasons.append("Parsed headings/page boundaries are available.") + + if chunking.get("parent_child", True): + strategies.append("parent_child") + reasons.append("Use child chunks for precise retrieval and parent chunks for context.") + + if profile.file_type == "pdf" and profile.page_count > 1 and chunking.get("page_level_for_paginated_docs", True): + strategies.append("page_level") + reasons.append("Page-level chunking is retained for paginated-document benchmarking.") + + if chunking.get("contextual_prefix", False) or chunking.get("contextual_retrieval", False): + strategies.append("contextual_retrieval") + reasons.append("Contextual retrieval prefixing is enabled by config.") + + optional = { + "semantic_chunking": "embedding-aware semantic boundaries require embeddings.", + "late_chunking": "late chunking requires long-context embedding token states.", + "vision_guided": "vision-guided chunking requires rendered page regions or a VLM.", + "agentic_proposition_chunking": "agentic/proposition chunking requires an LLM/VLM backend.", + } + for strategy, reason in optional.items(): + if chunking.get(strategy, False): + strategies.append(strategy) + reasons.append(f"{strategy} enabled by config.") + else: + deferred.append(f"{strategy}: {reason}") + + return ChunkingPlan( + strategies=_dedupe(strategies), + target_tokens=target_tokens, + overlap_tokens=overlap_tokens, + parent_target_tokens=parent_target_tokens, + reasons=reasons, + deferred_strategies=deferred, + metadata={ + "planner": chunking.get("planner", "agentic"), + "strategy_ladder": chunking.get("strategy_ladder", []), + }, + ) + + +def _has_structure(parsed: ParsedDocument) -> bool: + return any(element.type in {"title", "heading"} for element in parsed.elements) + + +def _dedupe(items: list[str]) -> list[str]: + seen: set[str] = set() + out: list[str] = [] + for item in items: + if item not in seen: + out.append(item) + seen.add(item) + return out diff --git a/zsgdp/chunking/semantic.py b/zsgdp/chunking/semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..5f65d58addb1363246797b0c6724ba068d3fb5d8 --- /dev/null +++ b/zsgdp/chunking/semantic.py @@ -0,0 +1,108 @@ +"""Embedding-aware semantic chunking hooks. + +The production path can replace these lexical scores with embedding cosine +similarities. Keeping a deterministic local fallback makes the strategy +observable in tests, benchmarks, and artifact exports without requiring a model +download. +""" + +from __future__ import annotations + +import re + + +def semantic_boundary_plan(sentences: list[str], similarity_scores: list[float], threshold: float = 0.72) -> list[int]: + """Return boundary indexes where adjacent sentence similarity drops.""" + + if len(similarity_scores) != max(len(sentences) - 1, 0): + raise ValueError("similarity_scores must have one score per adjacent sentence pair.") + return [index + 1 for index, score in enumerate(similarity_scores) if score < threshold] + + +def split_sentences(text: str) -> list[str]: + """Split text into sentence-like units without external NLP packages.""" + + normalized = re.sub(r"\s+", " ", text).strip() + if not normalized: + return [] + parts = re.split(r"(?<=[.!?])\s+|(?:\s+-\s+)", normalized) + return [part.strip() for part in parts if part.strip()] + + +def lexical_similarity_scores(sentences: list[str]) -> list[float]: + """Return adjacent Jaccard scores as a cheap embedding-similarity proxy.""" + + scores: list[float] = [] + for left, right in zip(sentences, sentences[1:]): + left_terms = _terms(left) + right_terms = _terms(right) + if not left_terms and not right_terms: + scores.append(1.0) + continue + union = left_terms | right_terms + scores.append(len(left_terms & right_terms) / len(union) if union else 0.0) + return scores + + +def semantic_groups( + text: str, + *, + threshold: float = 0.18, + max_tokens: int = 512, +) -> tuple[list[tuple[list[str], int, int]], list[float], list[int]]: + """Group sentence-like units at lexical topic-shift boundaries.""" + + sentences = split_sentences(text) + if not sentences: + return [], [], [] + if len(sentences) == 1: + return [([sentences[0]], 0, 1)], [], [] + + scores = lexical_similarity_scores(sentences) + boundaries = semantic_boundary_plan(sentences, scores, threshold) + boundary_set = set(boundaries) + groups: list[tuple[list[str], int, int]] = [] + current: list[str] = [] + start_index = 0 + for index, sentence in enumerate(sentences): + would_exceed = _token_count(" ".join(current + [sentence])) > max_tokens + should_split = index in boundary_set or (would_exceed and current) + if should_split: + groups.append((current, start_index, index)) + current = [] + start_index = index + current.append(sentence) + if current: + groups.append((current, start_index, len(sentences))) + return groups, scores, boundaries + + +def _terms(text: str) -> set[str]: + return { + token + for token in re.findall(r"[a-zA-Z0-9]+", text.lower()) + if len(token) > 2 and token not in _STOPWORDS + } + + +def _token_count(text: str) -> int: + return len(text.split()) + + +_STOPWORDS = { + "and", + "are", + "but", + "for", + "from", + "has", + "have", + "into", + "not", + "the", + "this", + "that", + "was", + "were", + "with", +} diff --git a/zsgdp/chunking/splitters.py b/zsgdp/chunking/splitters.py new file mode 100644 index 0000000000000000000000000000000000000000..a21b020faa159a03bded64843029f49251bd93b7 --- /dev/null +++ b/zsgdp/chunking/splitters.py @@ -0,0 +1,65 @@ +"""Small local splitters used as deterministic baselines.""" + +from __future__ import annotations + + +def token_count(text: str) -> int: + return len(text.split()) + + +def fixed_token_split(text: str, target_tokens: int, overlap_tokens: int) -> list[str]: + tokens = text.split() + if not tokens: + return [] + if len(tokens) <= target_tokens: + return [" ".join(tokens)] + + chunks: list[str] = [] + step = max(target_tokens - overlap_tokens, 1) + start = 0 + while start < len(tokens): + end = min(start + target_tokens, len(tokens)) + chunks.append(" ".join(tokens[start:end])) + if end == len(tokens): + break + start += step + return chunks + + +def recursive_split(text: str, target_tokens: int, overlap_tokens: int) -> list[str]: + if token_count(text) <= target_tokens: + return [text.strip()] if text.strip() else [] + + pieces = _split_by_separators(text, ["\n\n", "\n", ". ", " "]) + chunks: list[str] = [] + current: list[str] = [] + for piece in pieces: + candidate = "\n\n".join(current + [piece]).strip() + if token_count(candidate) <= target_tokens: + current.append(piece) + continue + if current: + chunks.extend(fixed_token_split("\n\n".join(current), target_tokens, overlap_tokens)) + current = [piece] + if current: + chunks.extend(fixed_token_split("\n\n".join(current), target_tokens, overlap_tokens)) + return [chunk for chunk in chunks if chunk.strip()] + + +def _split_by_separators(text: str, separators: list[str]) -> list[str]: + if not separators: + return [text] + separator = separators[0] + pieces = [piece.strip() for piece in text.split(separator) if piece.strip()] + if all(token_count(piece) <= 512 for piece in pieces): + if separator == ". ": + return [piece + "." if not piece.endswith(".") else piece for piece in pieces] + return pieces + + out: list[str] = [] + for piece in pieces: + if token_count(piece) <= 512: + out.append(piece) + else: + out.extend(_split_by_separators(piece, separators[1:])) + return out diff --git a/zsgdp/chunking/table_chunks.py b/zsgdp/chunking/table_chunks.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4e99d53e808ba92737c0b7752b3d11d0e5841b --- /dev/null +++ b/zsgdp/chunking/table_chunks.py @@ -0,0 +1,9 @@ +"""Table chunk helpers.""" + +from __future__ import annotations + +from zsgdp.schema import TableObject + + +def table_chunk_text(table: TableObject) -> str: + return (table.natural_language_rendering or table.markdown or table.html or table.caption or "").strip() diff --git a/zsgdp/chunking/vision_guided.py b/zsgdp/chunking/vision_guided.py new file mode 100644 index 0000000000000000000000000000000000000000..4470c6588ee74ee7681e784c5642762052d1577c --- /dev/null +++ b/zsgdp/chunking/vision_guided.py @@ -0,0 +1,59 @@ +"""Vision-guided chunking hooks.""" + +from __future__ import annotations + +from zsgdp.schema import Chunk, ParsedDocument, TableObject, FigureObject + + +def candidate_visual_regions(parsed: ParsedDocument) -> list[dict]: + """Return table/figure regions that a future VLM chunker should inspect.""" + + regions: list[dict] = [] + for table in parsed.tables: + regions.append( + { + "region_id": table.table_id, + "type": "table", + "page_nums": table.page_nums, + "bbox": table.bbox, + "source_parser": table.source_parser, + } + ) + for figure in parsed.figures: + regions.append( + { + "region_id": figure.figure_id, + "type": "figure", + "page_nums": [figure.page_num], + "bbox": figure.bbox, + "source_parser": figure.source_parser, + } + ) + return regions + + +def mark_visual_context_required(chunks: list[Chunk]) -> list[Chunk]: + for chunk in chunks: + if chunk.content_type in {"table", "figure"}: + chunk.requires_visual_context = True + return chunks + + +def table_region_text(table: TableObject) -> str: + page_text = _page_text(table.page_nums) + body = table.natural_language_rendering or table.markdown or table.html or table.caption or "" + prefix = f"Visual table region {table.table_id} on {page_text}." + return f"{prefix}\n\n{body}".strip() + + +def figure_region_text(figure: FigureObject) -> str: + body = figure.vlm_description or figure.caption or f"Figure {figure.figure_id}." + return f"Visual figure region {figure.figure_id} on page {figure.page_num}.\n\n{body}".strip() + + +def _page_text(page_nums: list[int]) -> str: + if not page_nums: + return "unknown pages" + if min(page_nums) == max(page_nums): + return f"page {page_nums[0]}" + return f"pages {min(page_nums)}-{max(page_nums)}" diff --git a/zsgdp/cli.py b/zsgdp/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..753ac2af44b9fce633aee0574ce0a90b36af0207 --- /dev/null +++ b/zsgdp/cli.py @@ -0,0 +1,646 @@ +"""Command-line interface.""" + +from __future__ import annotations + +import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +import json +from pathlib import Path +import shutil +from typing import Sequence + +from zsgdp.artifacts import validate_artifact_manifest +from zsgdp.benchmarks.ablation_runner import run_parser_ablations +from zsgdp.benchmarks.cross_dataset import combine_benchmark_summaries, write_cross_dataset_outputs +from zsgdp.benchmarks.parser_quality import run_parser_benchmark +from zsgdp.config import load_env_file +from zsgdp.logging_config import configure_logging +from zsgdp.preflight import format_failures, format_summary, run_preflight +from zsgdp.config import load_config +from zsgdp.deployment import check_huggingface_space +from zsgdp.gpu import collect_gpu_runtime_status, run_gpu_task_manifest +from zsgdp.parsers.registry import get_parser, parser_names +from zsgdp.pipeline import parse_document +from zsgdp.profiling import profile_document +from zsgdp.utils import dumps_json, write_json + + +def _epilog(text: str) -> str: + """Format a multi-line examples block for argparse epilog. + + Dedents the source-indented triple-quoted string so the rendered help + output isn't pushed to the right by however far the call site happens + to be nested. + """ + + import textwrap + + dedented = textwrap.dedent(text).strip("\n") + return "Examples:\n" + "\n".join(f" {line}" if line else "" for line in dedented.splitlines()) + + +def main(argv: Sequence[str] | None = None) -> int: + load_env_file() + configure_logging() + parser = argparse.ArgumentParser( + prog="zsgdp", + description="Zero-shot GPU document parser control plane.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + zsgdp parse --input ./docs/sample.md --output ./out/sample + zsgdp benchmark --input ./docs --output ./bench + zsgdp preflight --root . + + See README.md and docs/space_smoke.md for end-to-end workflows. + """ + ), + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + parse_parser = subparsers.add_parser( + "parse", + help="Parse one document.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + zsgdp parse --input ./docs/report.pdf --output ./out/report + zsgdp parse --input ./docs/report.pdf --output ./out/report --config configs/docling.yaml + zsgdp parse --input ./docs/report.pdf --output ./out/report --parser docling --parser pymupdf + """ + ), + ) + parse_parser.add_argument("--input", required=True, help="Input document path.") + parse_parser.add_argument("--output", required=True, help="Output directory.") + parse_parser.add_argument("--config", help="Optional YAML config path.") + parse_parser.add_argument("--parser", action="append", dest="parsers", help="Force a parser. Can be repeated.") + parse_parser.add_argument("--parsers", nargs="+", dest="parser_list", help="Force one or more parsers.") + + folder_parser = subparsers.add_parser( + "parse-folder", + help="Parse every file in a folder.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + zsgdp parse-folder --input ./docs --output ./parsed --workers 4 + zsgdp parse-folder --input ./docs --output ./parsed --workers 8 --gpu-workers 2 --config configs/docling.yaml + """ + ), + ) + folder_parser.add_argument("--input", required=True, help="Input folder.") + folder_parser.add_argument("--output", required=True, help="Output folder.") + folder_parser.add_argument("--config", help="Optional YAML config path.") + folder_parser.add_argument("--workers", type=int, default=1, help="Number of documents to parse concurrently.") + folder_parser.add_argument( + "--gpu-workers", + type=int, + default=0, + help="Record reserved GPU worker slots for downstream task execution; document parsing uses --workers.", + ) + folder_parser.add_argument("--parser", action="append", dest="parsers", help="Force a parser. Can be repeated.") + folder_parser.add_argument("--parsers", nargs="+", dest="parser_list", help="Force one or more parsers.") + + profile_parser = subparsers.add_parser("profile", help="Profile a document without parsing.") + profile_parser.add_argument("--input", required=True, help="Input document path.") + + gpu_parser = subparsers.add_parser("gpu-status", help="Print GPU/model runtime status.") + gpu_parser.add_argument("--config", help="Optional YAML config path.") + + space_parser = subparsers.add_parser( + "space-check", + help="Check Hugging Face Space deployment readiness.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + zsgdp space-check --root . + zsgdp space-check --root . --output ./space_report.json + """ + ), + ) + space_parser.add_argument("--root", default=".", help="Repository root to check.") + space_parser.add_argument("--config", help="Optional YAML config path.") + space_parser.add_argument("--output", help="Optional JSON readiness report path.") + + task_parser = subparsers.add_parser( + "run-gpu-tasks", + help="Validate and optionally execute a gpu_tasks.jsonl manifest.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + # Dry-run preflight (default — no model invoked): + zsgdp run-gpu-tasks --input ./out/report --output ./out/report/gpu_task_report.json + + # Live execution against the configured backend: + zsgdp run-gpu-tasks --input ./out/report --output ./out/report/gpu_task_report.json --execute + """ + ), + ) + task_parser.add_argument("--input", required=True, help="Parsed output directory or gpu_tasks.jsonl path.") + task_parser.add_argument("--output", required=True, help="Execution report JSON path.") + task_parser.add_argument("--config", help="Optional YAML config path.") + task_parser.add_argument("--execute", action="store_true", help="Execute ready tasks with the configured GPU backend.") + + subparsers.add_parser("parsers", help="List parser adapters and availability.") + + bench_parser = subparsers.add_parser( + "benchmark", + help="Run a parser/chunking benchmark over a folder.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + # Custom corpus, no GT (still emits all GT-free metrics): + zsgdp benchmark --input ./docs --output ./bench + + # OmniDocBench checkout (also runs layout F1 / table structure / formula CER): + zsgdp benchmark --input ./omnidocbench --dataset omnidocbench --output ./bench/omni + + # DocLayNet checkout (layout F1 only — DocLayNet has no table/formula GT): + zsgdp benchmark --input ./doclaynet --dataset doclaynet --output ./bench/doclay + + # Force a specific parser combo: + zsgdp benchmark --input ./docs --output ./bench --parser docling --parser pymupdf + """ + ), + ) + bench_parser.add_argument("--input", required=False, help="Input folder of documents.") + bench_parser.add_argument( + "--dataset", + required=False, + default="custom_folder", + help="Dataset loader name (custom_folder, omnidocbench, doclaynet). 'custom' is accepted as an alias.", + ) + bench_parser.add_argument("--output", required=False, default="./benchmarks/results") + bench_parser.add_argument("--config", help="Optional YAML config path.") + bench_parser.add_argument("--parser", action="append", dest="parsers", help="Force a parser. Can be repeated.") + bench_parser.add_argument("--parsers", nargs="+", dest="parser_list", help="Force one or more parsers.") + + ablate_parser = subparsers.add_parser( + "benchmark-ablate", + help="Run the benchmark once per parser in isolation plus a merged arm, and emit a comparison.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + # Two-parser ablation with the merged arm: + zsgdp benchmark-ablate --input ./docs --output ./bench/ablation \\ + --parser docling --parser pymupdf + + # Three parsers, no merged arm: + zsgdp benchmark-ablate --input ./docs --output ./bench/ablation \\ + --parser docling --parser pymupdf --parser text --no-merged + """ + ), + ) + ablate_parser.add_argument("--input", required=True, help="Input folder of documents.") + ablate_parser.add_argument( + "--dataset", + required=False, + default="custom_folder", + help="Dataset loader name. 'custom' aliases to custom_folder.", + ) + ablate_parser.add_argument("--output", required=False, default="./benchmarks/ablations") + ablate_parser.add_argument("--config", help="Optional YAML config path.") + ablate_parser.add_argument( + "--parser", + action="append", + dest="ablate_parsers", + required=True, + help="Parser to include as an ablation arm. Repeat to add more.", + ) + ablate_parser.add_argument( + "--no-merged", + dest="include_merged", + action="store_false", + default=True, + help="Skip the all-parsers-together merged arm.", + ) + + preflight_parser = subparsers.add_parser( + "preflight", + help="Run all local guards (unit tests, regression fixtures, space-check, parser registry) before pushing to a Space.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + # Standard preflight (~10s): + zsgdp preflight --root . + + # Add an end-to-end benchmark smoke (adds ~1-3s): + zsgdp preflight --root . --benchmark + + # Skip slow steps when iterating locally: + zsgdp preflight --root . --skip-unit + """ + ), + ) + preflight_parser.add_argument("--root", default=".", help="Repository root to check.") + preflight_parser.add_argument("--skip-unit", action="store_true", help="Skip the unittest discovery step.") + preflight_parser.add_argument("--skip-regression", action="store_true", help="Skip the regression fixture step.") + preflight_parser.add_argument("--skip-space-check", action="store_true", help="Skip the Space readiness check.") + preflight_parser.add_argument("--skip-parsers", action="store_true", help="Skip the parser registry sanity step.") + preflight_parser.add_argument( + "--benchmark", + action="store_true", + help="Also run an end-to-end benchmark against tests/regression/fixtures (off by default).", + ) + + combine_parser = subparsers.add_parser( + "combine-benchmarks", + help="Combine multiple benchmark summaries into a cross-dataset comparison.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + # Compare OmniDocBench vs DocLayNet runs: + zsgdp combine-benchmarks \\ + --input ./bench/omni --label omnidocbench \\ + --input ./bench/doclay --label doclaynet \\ + --output ./bench/cross + + # Without explicit labels (uses dataset_name from each summary): + zsgdp combine-benchmarks \\ + --input ./bench/omni \\ + --input ./bench/doclay \\ + --output ./bench/cross + """ + ), + ) + combine_parser.add_argument( + "--input", + action="append", + dest="combine_inputs", + required=True, + help="Benchmark output directory or results.json path. Repeat once per dataset.", + ) + combine_parser.add_argument( + "--label", + action="append", + dest="combine_labels", + help="Optional label per --input (defaults to dataset_name from each summary).", + ) + combine_parser.add_argument("--output", required=True, help="Output directory for the comparison artifacts.") + + export_parser = subparsers.add_parser( + "export-chunks", + help="Export chunks from a parsed document directory.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + zsgdp export-chunks --parsed ./out/sample --format jsonl --output ./chunks.jsonl + zsgdp export-chunks --parsed ./out/sample --format json --output ./chunks.json + """ + ), + ) + export_parser.add_argument("--parsed", required=True, help="Parsed document output directory.") + export_parser.add_argument("--format", choices=["jsonl", "json"], default="jsonl", help="Output format.") + export_parser.add_argument("--output", required=True, help="Output file path.") + + validate_parser = subparsers.add_parser( + "validate-artifacts", + help="Validate artifact_manifest.json checksums.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_epilog( + """ + zsgdp validate-artifacts --parsed ./out/sample + zsgdp validate-artifacts --parsed ./out/sample --output ./validation.json + """ + ), + ) + validate_parser.add_argument("--parsed", required=True, help="Parsed document output directory.") + validate_parser.add_argument("--output", help="Optional JSON validation report path.") + + args = parser.parse_args(argv) + if args.command == "parse": + parsed = parse_document(args.input, args.output, config_path=args.config, selected_parsers=_selected_parsers(args)) + _print_parse_summary(parsed, Path(args.output)) + return 0 + + if args.command == "parse-folder": + if args.workers < 1: + parser.error("parse-folder --workers must be >= 1") + if args.gpu_workers < 0: + parser.error("parse-folder --gpu-workers must be >= 0") + input_dir = Path(args.input) + if not input_dir.is_dir(): + parser.error(f"parse-folder input must be a folder: {input_dir}") + summary = _parse_folder( + input_dir, + Path(args.output), + config_path=args.config, + selected_parsers=_selected_parsers(args), + workers=args.workers, + gpu_workers=args.gpu_workers, + ) + for result in summary["results"]: + if result["status"] == "parsed": + print( + f"parsed {result['file']} -> {result['output']} " + f"score={result['quality_score']:.2f} chunks={result['chunks']}" + ) + else: + print(f"failed {result['file']} -> {result['output']} error={result['error']}") + print( + f"parsed {summary['success_count']} file(s), " + f"failed {summary['failure_count']} file(s), " + f"workers={summary['workers']} gpu_workers={summary['gpu_workers']}" + ) + return 0 if summary["failure_count"] == 0 else 1 + + if args.command == "profile": + print(dumps_json(profile_document(args.input))) + return 0 + + if args.command == "gpu-status": + print(dumps_json(collect_gpu_runtime_status(load_config(args.config)).to_dict())) + return 0 + + if args.command == "space-check": + report = check_huggingface_space(args.root, config_path=args.config) + if args.output: + write_json(args.output, report) + print( + f"valid={report['valid']} target={report['target']} space={report['space_name']} " + f"failures={report['failure_count']} warnings={report['warning_count']}" + ) + return 0 if report["valid"] else 1 + + if args.command == "run-gpu-tasks": + report = run_gpu_task_manifest( + args.input, + config=load_config(args.config), + output_path=args.output, + dry_run=not args.execute, + ) + print( + f"gpu_tasks={report['task_count']} batches={report['batch_count']} " + f"ready={report['ready_count']} blocked={report['blocked_count']} " + f"executed={report.get('executed_count', 0)} failed={report.get('failed_count', 0)} " + f"report={args.output}" + ) + return 0 + + if args.command == "parsers": + config = load_config() + for name in parser_names(): + adapter = get_parser(name) + enabled = config.get("parsers", {}).get(name, {}).get("enabled", False) + print(f"{name}\tenabled={enabled}\tavailable={adapter.available()}") + return 0 + + if args.command == "benchmark": + if not args.input: + parser.error("benchmark requires --input") + summary = run_parser_benchmark( + args.input, + args.output, + config_path=args.config, + selected_parsers=_selected_parsers(args), + dataset_name=args.dataset, + ) + print(f"dataset={summary.get('dataset_name', args.dataset)}") + print(f"documents={summary['document_count']} mean_quality_score={summary['mean_quality_score']:.2f}") + print(f"leaderboard={Path(args.output) / 'leaderboard.csv'}") + return 0 + + if args.command == "benchmark-ablate": + comparison = run_parser_ablations( + args.input, + args.output, + parsers=args.ablate_parsers, + config_path=args.config, + dataset_name=args.dataset, + include_merged=args.include_merged, + ) + print(f"arms={comparison['arm_count']} comparison={Path(args.output) / 'ablation_comparison.csv'}") + for row in comparison["rows"]: + quality = row.get("mean_quality_score", 0.0) + layout = row.get("mean_layout_f1", 0.0) + recall = row.get("mean_retrieval_recall_at_1", 0.0) + print(f" arm={row['arm']:<14} quality={quality:.2f} layout_f1={layout:.2f} recall@1={recall:.2f}") + return 0 + + if args.command == "preflight": + import sys as _sys + + result = run_preflight( + root=args.root, + skip_unit=args.skip_unit, + skip_regression=args.skip_regression, + skip_space_check=args.skip_space_check, + skip_parsers=args.skip_parsers, + run_benchmark=args.benchmark, + ) + print(format_summary(result)) + if not result.passed: + failures = format_failures(result) + if failures: + print("\n" + failures, file=_sys.stderr) + return 1 + return 0 + + if args.command == "combine-benchmarks": + labels = list(args.combine_labels or []) + if labels and len(labels) != len(args.combine_inputs): + parser.error("combine-benchmarks: --label must be passed once per --input or omitted entirely.") + pairs = [] + for index, source in enumerate(args.combine_inputs): + label = labels[index] if labels else None + if label is None: + from zsgdp.benchmarks.cross_dataset import _load_summary + + summary_for_default = _load_summary(source) + label = str(summary_for_default.get("dataset_name") or f"run_{index + 1}") + pairs.append((label, source)) + comparison = combine_benchmark_summaries(pairs) + write_cross_dataset_outputs(comparison, args.output) + print(f"combined {comparison['run_count']} run(s) -> {args.output}") + for row in comparison["dataset_summary"]: + print( + f" {row['label']:<14} docs={row.get('document_count') or 0} " + f"layout_f1={row.get('mean_layout_f1') or 0:.2f} " + f"recall@5={row.get('mean_retrieval_recall_at_5') or 0:.2f}" + ) + return 0 + + if args.command == "export-chunks": + exported = _export_chunks(Path(args.parsed), Path(args.output), args.format) + print(f"exported {exported} chunk(s) -> {args.output}") + return 0 + + if args.command == "validate-artifacts": + report = validate_artifact_manifest(args.parsed) + if args.output: + write_json(args.output, report) + print(f"valid={report['valid']} checked={report['checked_count']} errors={len(report['errors'])}") + return 0 if report["valid"] else 1 + + parser.error(f"Unhandled command: {args.command}") + return 2 + + +def _print_parse_summary(parsed, output_dir: Path) -> None: + print(f"doc_id={parsed.doc_id}") + print(f"file_type={parsed.file_type}") + print(f"elements={len(parsed.elements)} tables={len(parsed.tables)} figures={len(parsed.figures)} chunks={len(parsed.chunks)}") + print(f"quality_score={parsed.quality_report.score:.2f} blocking={parsed.quality_report.has_blocking_failures}") + print(f"output={output_dir}") + + +def _selected_parsers(args) -> list[str] | None: + selected = list(getattr(args, "parsers", None) or []) + selected.extend(getattr(args, "parser_list", None) or []) + return selected or None + + +@dataclass(slots=True) +class _FolderParseJob: + index: int + path: Path + output_dir: Path + + +def _parse_folder( + input_dir: Path, + output_dir: Path, + *, + config_path: str | Path | None, + selected_parsers: Sequence[str] | None, + workers: int, + gpu_workers: int = 0, +) -> dict: + if not input_dir.is_dir(): + raise NotADirectoryError(f"Input folder does not exist: {input_dir}") + + output_dir.mkdir(parents=True, exist_ok=True) + jobs = _build_folder_jobs(input_dir, output_dir) + if not jobs: + return { + "workers": workers, + "gpu_workers": gpu_workers, + "success_count": 0, + "failure_count": 0, + "results": [], + } + + if workers == 1: + results = [ + _parse_folder_job(job, config_path=config_path, selected_parsers=selected_parsers) + for job in jobs + ] + else: + results_by_index: list[dict | None] = [None] * len(jobs) + max_workers = min(workers, len(jobs)) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_job = { + executor.submit( + _parse_folder_job, + job, + config_path=config_path, + selected_parsers=selected_parsers, + ): job + for job in jobs + } + for future in as_completed(future_to_job): + job = future_to_job[future] + results_by_index[job.index] = future.result() + results = [result for result in results_by_index if result is not None] + + failures = [result for result in results if result["status"] != "parsed"] + return { + "workers": min(workers, len(jobs)), + "gpu_workers": max(gpu_workers, 0), + "success_count": len(results) - len(failures), + "failure_count": len(failures), + "results": results, + } + + +def _build_folder_jobs(input_dir: Path, output_dir: Path) -> list[_FolderParseJob]: + used_names: set[str] = set() + jobs: list[_FolderParseJob] = [] + for index, path in enumerate(sorted(item for item in input_dir.iterdir() if item.is_file())): + jobs.append( + _FolderParseJob( + index=index, + path=path, + output_dir=output_dir / _unique_output_name(path, used_names), + ) + ) + return jobs + + +def _unique_output_name(path: Path, used_names: set[str]) -> str: + base_name = path.stem or path.name + candidates = [base_name] + if path.suffix: + candidates.append(f"{base_name}-{path.suffix.lstrip('.')}") + + suffix = 2 + while True: + for candidate in candidates: + key = candidate.casefold() + if key not in used_names: + used_names.add(key) + return candidate + candidates = [f"{base_name}-{suffix}"] + suffix += 1 + + +def _parse_folder_job( + job: _FolderParseJob, + *, + config_path: str | Path | None, + selected_parsers: Sequence[str] | None, +) -> dict: + try: + parsed = parse_document( + job.path, + job.output_dir, + config_path=config_path, + selected_parsers=selected_parsers, + ) + except Exception as exc: + return { + "status": "failed", + "file": job.path.name, + "output": str(job.output_dir), + "error": str(exc), + } + return { + "status": "parsed", + "file": job.path.name, + "output": str(job.output_dir), + "doc_id": parsed.doc_id, + "file_type": parsed.file_type, + "quality_score": parsed.quality_report.score, + "blocking": parsed.quality_report.has_blocking_failures, + "elements": len(parsed.elements), + "tables": len(parsed.tables), + "figures": len(parsed.figures), + "chunks": len(parsed.chunks), + } + + +def _export_chunks(parsed_dir: Path, output_path: Path, fmt: str) -> int: + chunks_path = parsed_dir / "chunks.jsonl" + if not chunks_path.exists(): + raise FileNotFoundError(f"Missing chunks artifact: {chunks_path}") + output_path.parent.mkdir(parents=True, exist_ok=True) + + if fmt == "jsonl": + shutil.copyfile(chunks_path, output_path) + return _count_jsonl(chunks_path) + + records = [ + json.loads(line) + for line in chunks_path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + output_path.write_text(json.dumps(records, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + return len(records) + + +def _count_jsonl(path: Path) -> int: + return sum(1 for line in path.read_text(encoding="utf-8").splitlines() if line.strip()) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/zsgdp/config.py b/zsgdp/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1b67533c514a9c486e0901c743c3a23c71cd6640 --- /dev/null +++ b/zsgdp/config.py @@ -0,0 +1,326 @@ +"""Configuration loading and defaults.""" + +from __future__ import annotations + +import os +from copy import deepcopy +from pathlib import Path +from typing import Any + + +DEFAULT_CONFIG: dict[str, Any] = { + "parsers": { + "text": {"enabled": True}, + "pymupdf": {"enabled": True}, + "docling": {"enabled": False}, + "marker": { + "enabled": False, + "command": None, + "timeout_seconds": 300, + "output_args": ["--output_dir", "{output_dir}", "--output_format", "markdown"], + "extra_args": [], + }, + "mineru": { + "enabled": False, + "command": None, + "timeout_seconds": 600, + "output_args": ["--output_dir", "{output_dir}"], + "extra_args": [], + }, + "olmocr": { + "enabled": False, + "command": None, + "timeout_seconds": 600, + "output_args": ["--output_dir", "{output_dir}"], + "extra_args": [], + }, + "paddleocr": { + "enabled": False, + "command": None, + "timeout_seconds": 600, + "output_args": ["--output_dir", "{output_dir}"], + "extra_args": [], + }, + "unstructured": {"enabled": False}, + }, + "routing": { + "run_multiple_on_hard_pages": True, + "max_primary_parsers_per_page": 2, + "hard_page_threshold": 0.65, + "scanned_text_threshold": 0.40, + "table_density_threshold": 0.25, + "formula_density_threshold": 0.15, + "figure_density_threshold": 0.20, + }, + "repair": { + "enabled": True, + "max_iterations": 3, + "gpu_escalation": True, + "execute_gpu_escalations": False, + "table_repair": True, + "reading_order_repair": True, + "figure_repair": True, + "ocr_repair": True, + }, + "gpu": { + "backend": "transformers", + "provider": "huggingface_spaces", + "space_name": "zeroshotGPU", + "batch_pages": True, + "validate_tasks": True, + "max_batch_size": 4, + "max_gpu_seconds_per_doc": 120, + "max_vlm_calls_per_doc": 30, + "models": { + "vlm": { + "model_id": "Qwen/Qwen2.5-VL-3B-Instruct", + "task": "image-text-to-text", + "device": "auto", + "dtype": "bfloat16", + "max_batch_size": 1, + }, + "ocr": { + "model_id": "Qwen/Qwen2.5-VL-3B-Instruct", + "task": "document-ocr", + "device": "auto", + "dtype": "bfloat16", + "max_batch_size": 1, + }, + "table": { + "model_id": "Qwen/Qwen2.5-VL-3B-Instruct", + "task": "table-repair", + "device": "auto", + "dtype": "bfloat16", + "max_batch_size": 1, + }, + "embedding": { + "model_id": "jinaai/jina-embeddings-v3", + "task": "retrieval.passage", + "device": "auto", + "dtype": "bfloat16", + "max_batch_size": 16, + }, + }, + "task_model_roles": { + "vlm_route_repair": "vlm", + "ocr_page": "ocr", + "table_vlm_repair": "table", + "figure_description": "vlm", + }, + }, + "pdf": { + "render_pages": True, + "render_dpi": 150, + "crop_tables": True, + "crop_figures": True, + "asset_dir": "assets", + }, + "quality": { + "accept_threshold": 0.88, + "blocking_failures": [ + "empty_page", + "invalid_table", + "missing_text_coverage", + "reading_order_failure", + ], + }, + "chunking": { + "enabled": True, + "planner": "agentic", + "baseline_strategy": "recursive_structure", + "target_tokens": 512, + "min_tokens": 120, + "overlap_ratio": 0.15, + "parent_child": True, + "parent_target_tokens": 1600, + "page_level_for_paginated_docs": True, + "table_chunks": True, + "figure_chunks": True, + "contextual_prefix": False, + "contextual_retrieval": False, + "semantic_similarity_threshold": 0.18, + "max_propositions_per_source": 8, + "max_proposition_chunks": 64, + "semantic_chunking": False, + "late_chunking": False, + "vision_guided": False, + "agentic_proposition_chunking": False, + "strategy_ladder": [ + "fixed_token_baseline", + "recursive_structure", + "metadata_enriched", + "parent_child", + "contextual_retrieval", + "late_chunking", + "semantic_chunking", + "vision_guided", + "agentic_proposition", + ], + }, + "deployment": { + "target": "huggingface_spaces", + "gpu_models_target": "zeroshotGPU", + }, +} + + +def load_env_file(path: str | Path | None = None, *, override: bool = False) -> dict[str, str]: + """Read a `.env` file and populate os.environ for the current process. + + Each line is `KEY=VALUE`; blank lines and `#` comments are skipped. Values + may be quoted with single or double quotes (one set; nested quotes are not + handled — keep secrets simple). When `override` is False (the default), + pre-existing environment variables win — Space-side env config takes + precedence over a committed `.env`. + + Returns the mapping that was applied so callers can audit what was loaded + without re-reading the file. + """ + + env_path = Path(path) if path else Path(".env") + if not env_path.exists(): + return {} + applied: dict[str, str] = {} + for raw in env_path.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line or line.startswith("#"): + continue + if line.startswith("export "): + line = line[len("export ") :].strip() + key, sep, value = line.partition("=") + if not sep: + continue + key = key.strip() + value = value.strip() + if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")): + value = value[1:-1] + if not key: + continue + if not override and key in os.environ: + continue + os.environ[key] = value + applied[key] = value + return applied + + +def hf_token() -> str | None: + """Return the Hugging Face token from any of the conventional env vars. + + Order: HF_TOKEN, HUGGING_FACE_HUB_TOKEN, HUGGINGFACE_TOKEN, HF_ACCESS_TOKEN. + None of these are read at import time; downstream callers (transformers / + sentence-transformers / the model_server) pick up whichever is set when + they need network auth. + """ + + for var in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN", "HF_ACCESS_TOKEN"): + value = os.environ.get(var) + if value: + return value + return None + + +def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + """Recursively merge override into base and return a new dictionary.""" + + merged = deepcopy(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + merged[key] = deep_merge(merged[key], value) + else: + merged[key] = deepcopy(value) + return merged + + +def load_config(path: str | Path | None = None, overrides: dict[str, Any] | None = None) -> dict[str, Any]: + """Load config from YAML if provided, then merge with defaults.""" + + config = deepcopy(DEFAULT_CONFIG) + if path: + path_obj = Path(path) + try: + import yaml # type: ignore + except ImportError: + loaded = _load_simple_yaml(path_obj.read_text(encoding="utf-8")) + else: + loaded = yaml.safe_load(path_obj.read_text(encoding="utf-8")) or {} + if not isinstance(loaded, dict): + raise ValueError(f"Config file must contain a YAML mapping: {path_obj}") + config = deep_merge(config, loaded) + + if overrides: + config = deep_merge(config, overrides) + + return config + + +def _load_simple_yaml(text: str) -> dict[str, Any]: + """Parse the small YAML subset used by bundled configs. + + This fallback supports nested mappings, scalar values, and lists of scalar + values. It is not a general YAML parser; install PyYAML for broader syntax. + """ + + raw_lines = [ + (len(line) - len(line.lstrip(" ")), line.strip()) + for line in text.splitlines() + if line.strip() and not line.strip().startswith("#") + ] + root: dict[str, Any] = {} + stack: list[tuple[int, Any]] = [(-1, root)] + + for index, (indent, stripped) in enumerate(raw_lines): + while len(stack) > 1 and indent <= stack[-1][0]: + stack.pop() + parent = stack[-1][1] + + if stripped.startswith("- "): + if not isinstance(parent, list): + raise ValueError("Simple YAML parser found a list item outside a list.") + parent.append(_parse_scalar(stripped[2:].strip())) + continue + + key, separator, value = stripped.partition(":") + if not separator: + raise ValueError(f"Simple YAML parser expected `key: value`, got: {stripped}") + key = key.strip() + value = value.strip() + if not isinstance(parent, dict): + raise ValueError("Simple YAML parser found a mapping entry inside a scalar list.") + + if value: + parent[key] = _parse_scalar(value) + continue + + child: dict[str, Any] | list[Any] + child = [] if _next_child_is_list(raw_lines, index, indent) else {} + parent[key] = child + stack.append((indent, child)) + + return root + + +def _next_child_is_list(lines: list[tuple[int, str]], index: int, indent: int) -> bool: + for next_indent, next_line in lines[index + 1 :]: + if next_indent <= indent: + return False + return next_line.startswith("- ") + return False + + +def _parse_scalar(value: str) -> Any: + lowered = value.lower() + if lowered == "true": + return True + if lowered == "false": + return False + if lowered in {"null", "none"}: + return None + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + return value.strip('"').strip("'") diff --git a/zsgdp/deployment.py b/zsgdp/deployment.py new file mode 100644 index 0000000000000000000000000000000000000000..29713ed83b1a52294b83397c9d0f02ea4dfb97bd --- /dev/null +++ b/zsgdp/deployment.py @@ -0,0 +1,240 @@ +"""Deployment readiness checks for the Hugging Face Space.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from zsgdp.config import load_config +from zsgdp.gpu import collect_gpu_runtime_status + +REQUIRED_SPACE_FILES = ["README.md", "app.py", "requirements.txt"] +REQUIRED_REQUIREMENTS = ["gradio", "pymupdf", "pyyaml", "docling"] +REQUIRED_MODEL_ROLES = ["vlm", "ocr", "table", "embedding"] + + +def check_huggingface_space( + root: str | Path = ".", + *, + config_path: str | Path | None = None, +) -> dict[str, Any]: + """Return a local readiness report for the Hugging Face Spaces deployment.""" + + root_path = Path(root) + checks: list[dict[str, Any]] = [] + config_file = Path(config_path) if config_path else _default_space_config(root_path) + config = load_config(config_file if config_file.exists() else None) + runtime = collect_gpu_runtime_status(config).to_dict() + + _check_required_files(root_path, checks) + front_matter = _read_readme_front_matter(root_path / "README.md") + _check_space_metadata(root_path, front_matter, checks) + _check_requirements(root_path / "requirements.txt", checks) + _check_deployment_config(config, config_file, checks) + _check_model_roles(config, checks) + _check_runtime(runtime, checks) + + failures = [check for check in checks if check["status"] == "fail"] + warnings = [check for check in checks if check["status"] == "warn"] + return { + "valid": not failures, + "target": config.get("deployment", {}).get("target"), + "space_name": config.get("gpu", {}).get("space_name"), + "gpu_models_target": config.get("deployment", {}).get("gpu_models_target"), + "root": str(root_path.resolve()), + "config_path": str(config_file), + "failure_count": len(failures), + "warning_count": len(warnings), + "checks": checks, + "runtime": runtime, + } + + +def _default_space_config(root: Path) -> Path: + docling_config = root / "configs" / "docling.yaml" + return docling_config if docling_config.exists() else root / "configs" / "default.yaml" + + +def _check_required_files(root: Path, checks: list[dict[str, Any]]) -> None: + for relative_path in REQUIRED_SPACE_FILES: + path = root / relative_path + _add_check( + checks, + "required_file", + "pass" if path.exists() else "fail", + f"{relative_path} {'exists' if path.exists() else 'is missing'}", + {"path": relative_path}, + ) + + +def _check_space_metadata(root: Path, front_matter: dict[str, str], checks: list[dict[str, Any]]) -> None: + _add_check( + checks, + "space_sdk", + "pass" if front_matter.get("sdk") == "gradio" else "fail", + "README Space SDK is gradio.", + {"actual": front_matter.get("sdk")}, + ) + app_file = front_matter.get("app_file") + _add_check( + checks, + "space_app_file", + "pass" if app_file == "app.py" and (root / "app.py").exists() else "fail", + "README app_file points to app.py.", + {"actual": app_file}, + ) + _add_check( + checks, + "space_python_version", + "pass" if front_matter.get("python_version") == "3.11" else "warn", + "README declares Python 3.11 for the Space.", + {"actual": front_matter.get("python_version")}, + ) + _add_check( + checks, + "space_hardware", + "pass" if front_matter.get("suggested_hardware") else "warn", + "README declares suggested_hardware for GPU duplication.", + {"actual": front_matter.get("suggested_hardware")}, + ) + + +def _check_requirements(path: Path, checks: list[dict[str, Any]]) -> None: + requirements = _requirements(path) + for package in REQUIRED_REQUIREMENTS: + _add_check( + checks, + "requirement", + "pass" if package in requirements else "fail", + f"requirements.txt includes {package}.", + {"package": package}, + ) + + +def _check_deployment_config(config: dict[str, Any], config_path: Path, checks: list[dict[str, Any]]) -> None: + deployment = config.get("deployment", {}) + gpu = config.get("gpu", {}) + _add_check( + checks, + "config_file", + "pass" if config_path.exists() else "warn", + "Deployment config file is present.", + {"path": str(config_path)}, + ) + _add_check( + checks, + "deployment_target", + "pass" if deployment.get("target") == "huggingface_spaces" else "fail", + "Deployment target is huggingface_spaces.", + {"actual": deployment.get("target")}, + ) + _add_check( + checks, + "gpu_space_name", + "pass" if gpu.get("space_name") == "zeroshotGPU" else "fail", + "GPU Space name is zeroshotGPU.", + {"actual": gpu.get("space_name")}, + ) + _add_check( + checks, + "gpu_models_target", + "pass" if deployment.get("gpu_models_target") == "zeroshotGPU" else "fail", + "GPU/model target is zeroshotGPU.", + {"actual": deployment.get("gpu_models_target")}, + ) + + +def _check_model_roles(config: dict[str, Any], checks: list[dict[str, Any]]) -> None: + models = config.get("gpu", {}).get("models", {}) + task_roles = config.get("gpu", {}).get("task_model_roles", {}) + for role in REQUIRED_MODEL_ROLES: + model = models.get(role, {}) if isinstance(models, dict) else {} + model_id = model.get("model_id") if isinstance(model, dict) else None + _add_check( + checks, + "model_role", + "pass" if model_id else "fail", + f"GPU model role {role} has a model_id.", + {"role": role, "model_id": model_id}, + ) + for task_type in ["vlm_route_repair", "ocr_page", "table_vlm_repair", "figure_description"]: + role = task_roles.get(task_type) if isinstance(task_roles, dict) else None + _add_check( + checks, + "task_model_role", + "pass" if role in REQUIRED_MODEL_ROLES else "fail", + f"GPU task {task_type} maps to a configured model role.", + {"task_type": task_type, "role": role}, + ) + + +def _check_runtime(runtime: dict[str, Any], checks: list[dict[str, Any]]) -> None: + _add_check( + checks, + "runtime_provider", + "pass" if runtime.get("provider") == "huggingface_spaces" else "fail", + "Runtime provider is huggingface_spaces.", + {"actual": runtime.get("provider")}, + ) + _add_check( + checks, + "runtime_environment", + "pass" if runtime.get("running_on_huggingface_space") else "warn", + "Hugging Face Space environment variables are detected.", + {"running_on_huggingface_space": runtime.get("running_on_huggingface_space")}, + ) + _add_check( + checks, + "runtime_accelerator", + "pass" if runtime.get("device") in {"cuda", "mps"} else "warn", + "GPU/MPS accelerator is detected for model-backed repair.", + {"device": runtime.get("device")}, + ) + + +def _read_readme_front_matter(path: Path) -> dict[str, str]: + if not path.exists(): + return {} + lines = path.read_text(encoding="utf-8").splitlines() + if not lines or lines[0].strip() != "---": + return {} + metadata: dict[str, str] = {} + for line in lines[1:]: + if line.strip() == "---": + break + key, separator, value = line.partition(":") + if separator: + metadata[key.strip()] = value.strip().strip('"').strip("'") + return metadata + + +def _requirements(path: Path) -> set[str]: + if not path.exists(): + return set() + packages: set[str] = set() + for line in path.read_text(encoding="utf-8").splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + name = stripped.replace("_", "-").split("[", 1)[0] + for separator in ["==", ">=", "<=", "~=", ">", "<"]: + name = name.split(separator, 1)[0] + packages.add(name.strip().lower().replace("-", "")) + return packages + + +def _add_check( + checks: list[dict[str, Any]], + check_id: str, + status: str, + message: str, + metadata: dict[str, Any] | None = None, +) -> None: + checks.append( + { + "id": check_id, + "status": status, + "message": message, + "metadata": metadata or {}, + } + ) diff --git a/zsgdp/export.py b/zsgdp/export.py new file mode 100644 index 0000000000000000000000000000000000000000..5e006f9e530101d8ed51a0789550d65f6d8ac3df --- /dev/null +++ b/zsgdp/export.py @@ -0,0 +1,45 @@ +"""Output writers for parsed-document artifacts.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Iterable + +from zsgdp.artifacts import write_artifact_manifest +from zsgdp.schema import ParsedDocument +from zsgdp.utils import write_json, write_jsonl + + +def export_parsed_document( + parsed: ParsedDocument, + output_dir: str | Path, + *, + routing_report: Iterable[dict] | None = None, + profile: dict | None = None, +) -> None: + out = Path(output_dir) + out.mkdir(parents=True, exist_ok=True) + write_json(out / "parsed_document.json", parsed) + write_json(out / "quality_report.json", parsed.quality_report) + write_jsonl(out / "elements.jsonl", parsed.elements) + write_jsonl(out / "tables.jsonl", parsed.tables) + write_jsonl(out / "figures.jsonl", parsed.figures) + write_jsonl(out / "chunks.jsonl", parsed.chunks) + (out / "document.md").write_text(parsed.to_markdown(), encoding="utf-8") + if parsed.provenance.get("chunking"): + write_json(out / "chunking_plan.json", parsed.provenance["chunking"]) + if parsed.provenance.get("parser_metrics"): + write_json(out / "parser_metrics.json", parsed.provenance["parser_metrics"]) + if parsed.provenance.get("conflict_report"): + write_json(out / "conflict_report.json", parsed.provenance["conflict_report"]) + if parsed.provenance.get("gpu_runtime"): + write_json(out / "gpu_runtime.json", parsed.provenance["gpu_runtime"]) + if parsed.provenance.get("gpu_tasks"): + write_jsonl(out / "gpu_tasks.jsonl", parsed.provenance["gpu_tasks"]) + if parsed.provenance.get("gpu_task_report"): + write_json(out / "gpu_task_report.json", parsed.provenance["gpu_task_report"]) + if routing_report is not None: + write_json(out / "routing_report.json", list(routing_report)) + if profile is not None: + write_json(out / "profile.json", profile) + write_artifact_manifest(out, parsed) diff --git a/zsgdp/gpu/__init__.py b/zsgdp/gpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a7c219fa3f5e63b2e03fa10ab425547f5a19f8 --- /dev/null +++ b/zsgdp/gpu/__init__.py @@ -0,0 +1,15 @@ +from zsgdp.gpu.model_server import GPUModelConfig +from zsgdp.gpu.runtime import GPURuntimeStatus, collect_gpu_runtime_status +from zsgdp.gpu.runner import dry_run_gpu_tasks, load_gpu_tasks, run_gpu_task_manifest +from zsgdp.gpu.tasks import GPUTask, plan_gpu_tasks + +__all__ = [ + "GPUModelConfig", + "GPURuntimeStatus", + "GPUTask", + "collect_gpu_runtime_status", + "dry_run_gpu_tasks", + "load_gpu_tasks", + "plan_gpu_tasks", + "run_gpu_task_manifest", +] diff --git a/zsgdp/gpu/batching.py b/zsgdp/gpu/batching.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a70785f06656d12e0e036916721d0464356214 --- /dev/null +++ b/zsgdp/gpu/batching.py @@ -0,0 +1,50 @@ +"""Batching helpers for future GPU workers.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Iterable + + +def bucket_pages(page_records: Iterable[dict]) -> dict[str, list[dict]]: + buckets: dict[str, list[dict]] = defaultdict(list) + for page in page_records: + task = str(page.get("task", "parse")) + resolution = str(page.get("resolution_bucket", "default")) + buckets[f"{task}:{resolution}"].append(page) + return dict(buckets) + + +def batch_gpu_tasks(tasks: Iterable[dict], max_batch_size: int = 4) -> list[dict]: + buckets: dict[tuple[str, str, str, str, str, str], list[dict]] = defaultdict(list) + for task in tasks: + key = ( + str(task.get("provider", "huggingface_spaces")), + str(task.get("space_name", "zeroshotGPU")), + str(task.get("backend", "transformers")), + str(task.get("task_type", "unknown")), + str(task.get("model_role", "vlm")), + str(task.get("model_id", "")), + ) + buckets[key].append(task) + + batches: list[dict] = [] + safe_batch_size = max(int(max_batch_size), 1) + for (provider, space_name, backend, task_type, model_role, model_id), bucket in sorted(buckets.items()): + ordered = sorted(bucket, key=lambda item: (-int(item.get("priority", 0)), str(item.get("task_id", "")))) + for offset in range(0, len(ordered), safe_batch_size): + batch_tasks = ordered[offset : offset + safe_batch_size] + batches.append( + { + "batch_id": f"gb{len(batches) + 1}", + "provider": provider, + "space_name": space_name, + "backend": backend, + "task_type": task_type, + "model_role": model_role, + "model_id": model_id or None, + "task_count": len(batch_tasks), + "tasks": batch_tasks, + } + ) + return batches diff --git a/zsgdp/gpu/model_server.py b/zsgdp/gpu/model_server.py new file mode 100644 index 0000000000000000000000000000000000000000..b9954b13ae997ac7dd8a651a64aac16da2759306 --- /dev/null +++ b/zsgdp/gpu/model_server.py @@ -0,0 +1,23 @@ +"""GPU model server configuration.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(slots=True) +class GPUModelConfig: + backend: str = "transformers" + provider: str = "huggingface_spaces" + space_name: str = "zeroshotGPU" + max_batch_size: int = 4 + + @classmethod + def from_config(cls, config: dict) -> "GPUModelConfig": + gpu = config.get("gpu", {}) + return cls( + backend=gpu.get("backend", "transformers"), + provider=gpu.get("provider", "huggingface_spaces"), + space_name=gpu.get("space_name", "zeroshotGPU"), + max_batch_size=int(gpu.get("max_batch_size", 4)), + ) diff --git a/zsgdp/gpu/runner.py b/zsgdp/gpu/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..b88f211e936ca7eed883b3a5a0b247003785534f --- /dev/null +++ b/zsgdp/gpu/runner.py @@ -0,0 +1,75 @@ +"""Dry-run execution for planned GPU task manifests.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from zsgdp.gpu.batching import batch_gpu_tasks +from zsgdp.gpu.runtime import collect_gpu_runtime_status +from zsgdp.gpu.worker import GPUWorker +from zsgdp.utils import write_json + + +def load_gpu_tasks(path: str | Path) -> list[dict[str, Any]]: + task_path = _task_manifest_path(path) + if not task_path.exists(): + raise FileNotFoundError(f"Missing GPU task manifest: {task_path}") + return [ + json.loads(line) + for line in task_path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + + +def run_gpu_task_manifest( + input_path: str | Path, + *, + config: dict[str, Any], + output_path: str | Path | None = None, + dry_run: bool = True, +) -> dict[str, Any]: + tasks = load_gpu_tasks(input_path) + report = dry_run_gpu_tasks(tasks, config=config, dry_run=dry_run) + if output_path is not None: + write_json(output_path, report) + return report + + +def dry_run_gpu_tasks( + tasks: list[dict[str, Any]], + *, + config: dict[str, Any], + dry_run: bool = True, + runtime_status: dict[str, Any] | None = None, +) -> dict[str, Any]: + runtime = runtime_status or collect_gpu_runtime_status(config).to_dict() + max_batch_size = int(config.get("gpu", {}).get("max_batch_size", runtime.get("max_batch_size", 4))) + batches = batch_gpu_tasks(tasks, max_batch_size=max_batch_size) + worker = GPUWorker(config) + batch_results = [worker.run_batch(batch, dry_run=dry_run) for batch in batches] + + ready_count = sum(result["ready_count"] for result in batch_results) + blocked_count = sum(result["blocked_count"] for result in batch_results) + executed_count = sum(result.get("executed_count", 0) for result in batch_results) + failed_count = sum(result.get("failed_count", 0) for result in batch_results) + report = { + "dry_run": dry_run, + "task_count": len(tasks), + "batch_count": len(batches), + "ready_count": ready_count, + "blocked_count": blocked_count, + "executed_count": executed_count, + "failed_count": failed_count, + "runtime": runtime, + "batches": batch_results, + } + return report + + +def _task_manifest_path(path: str | Path) -> Path: + path_obj = Path(path) + if path_obj.is_dir(): + return path_obj / "gpu_tasks.jsonl" + return path_obj diff --git a/zsgdp/gpu/runtime.py b/zsgdp/gpu/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..8f51a5d23a44e1fe078f4029364236cb34b50fad --- /dev/null +++ b/zsgdp/gpu/runtime.py @@ -0,0 +1,112 @@ +"""GPU runtime and Hugging Face Spaces status helpers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import os +from typing import Any + +from zsgdp.gpu.model_server import GPUModelConfig +from zsgdp.utils import to_plain_data + + +@dataclass(slots=True) +class GPURuntimeStatus: + provider: str + backend: str + space_name: str + gpu_models_target: str + running_on_huggingface_space: bool + space_id: str | None + hardware: str | None + device: str + torch_available: bool + torch_version: str | None = None + cuda_available: bool = False + cuda_device_count: int = 0 + cuda_devices: list[str] = field(default_factory=list) + mps_available: bool = False + batch_pages: bool = True + max_batch_size: int = 4 + max_gpu_seconds_per_doc: float = 120.0 + max_vlm_calls_per_doc: int = 30 + configured_models: dict[str, Any] = field(default_factory=dict) + notes: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +def collect_gpu_runtime_status(config: dict[str, Any]) -> GPURuntimeStatus: + gpu = config.get("gpu", {}) + deployment = config.get("deployment", {}) + model_config = GPUModelConfig.from_config(config) + torch_status = _torch_status() + running_on_space = bool(os.environ.get("SPACE_ID") or os.environ.get("SPACE_HOST")) + hardware = os.environ.get("SPACE_HARDWARE") or os.environ.get("HF_SPACE_HARDWARE") + device = _preferred_device(torch_status) + + notes: list[str] = [] + if not running_on_space: + notes.append("Hugging Face Spaces environment variables were not detected; this looks like a local run.") + if device == "cpu": + notes.append("No CUDA or MPS accelerator was detected by PyTorch.") + elif device == "cuda": + notes.append("CUDA accelerator detected.") + elif device == "mps": + notes.append("Apple MPS accelerator detected.") + if model_config.provider == "huggingface_spaces" and not hardware: + notes.append("No Space hardware label was found; set hardware in the Space settings for GPU deployment.") + + return GPURuntimeStatus( + provider=model_config.provider, + backend=model_config.backend, + space_name=model_config.space_name, + gpu_models_target=str(deployment.get("gpu_models_target", model_config.space_name)), + running_on_huggingface_space=running_on_space, + space_id=os.environ.get("SPACE_ID"), + hardware=hardware, + device=device, + batch_pages=bool(gpu.get("batch_pages", True)), + max_batch_size=model_config.max_batch_size, + max_gpu_seconds_per_doc=float(gpu.get("max_gpu_seconds_per_doc", 120)), + max_vlm_calls_per_doc=int(gpu.get("max_vlm_calls_per_doc", 30)), + configured_models=dict(gpu.get("models", {})), + notes=notes, + **torch_status, + ) + + +def _torch_status() -> dict[str, Any]: + try: + import torch # type: ignore + except Exception: + return { + "torch_available": False, + "torch_version": None, + "cuda_available": False, + "cuda_device_count": 0, + "cuda_devices": [], + "mps_available": False, + } + + cuda_available = bool(torch.cuda.is_available()) + cuda_device_count = int(torch.cuda.device_count()) if cuda_available else 0 + cuda_devices = [torch.cuda.get_device_name(index) for index in range(cuda_device_count)] + mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()) + return { + "torch_available": True, + "torch_version": getattr(torch, "__version__", None), + "cuda_available": cuda_available, + "cuda_device_count": cuda_device_count, + "cuda_devices": cuda_devices, + "mps_available": mps_available, + } + + +def _preferred_device(torch_status: dict[str, Any]) -> str: + if torch_status.get("cuda_available"): + return "cuda" + if torch_status.get("mps_available"): + return "mps" + return "cpu" diff --git a/zsgdp/gpu/tasks.py b/zsgdp/gpu/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a5d6081f445d9c399ce2630f86eb9f671f61aa --- /dev/null +++ b/zsgdp/gpu/tasks.py @@ -0,0 +1,286 @@ +"""GPU/VLM task planning for model-backed repair and visual chunking.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from zsgdp.gpu.runtime import collect_gpu_runtime_status +from zsgdp.routing import RouteDecision +from zsgdp.schema import DocumentProfile, ParsedDocument +from zsgdp.utils import to_plain_data + + +@dataclass(slots=True) +class GPUTask: + task_id: str + task_type: str + doc_id: str + page_nums: list[int] + priority: int + status: str = "planned" + region_id: str | None = None + image_path: str | None = None + bbox: Any = None + model_role: str = "vlm" + provider: str = "huggingface_spaces" + space_name: str = "zeroshotGPU" + backend: str = "transformers" + gpu_models_target: str = "zeroshotGPU" + reason: str = "" + model_id: str | None = None + budget: dict[str, Any] = field(default_factory=dict) + inputs: dict[str, Any] = field(default_factory=dict) + model_config: dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +def plan_gpu_tasks( + profile: DocumentProfile, + parsed: ParsedDocument, + config: dict[str, Any], + route_decisions: list[RouteDecision] | None = None, + runtime_status: dict[str, Any] | None = None, +) -> list[dict[str, Any]]: + """Plan model-backed tasks without executing them. + + The pipeline remains self-hosted and deterministic by default. These tasks + make the GPU/model work explicit for Hugging Face Spaces workers, VLM repair, + and future asynchronous execution. + """ + + runtime = runtime_status or collect_gpu_runtime_status(config).to_dict() + budget = _budget(config) + tasks: list[GPUTask] = [] + + for decision in route_decisions or []: + for expert in decision.experts: + if expert.startswith("vlm_"): + tasks.append( + _task( + tasks, + "vlm_route_repair", + parsed.doc_id, + [decision.page_id], + priority=90, + runtime=runtime, + config=config, + budget=budget, + reason=f"Router selected {expert}: {decision.reason}", + inputs={"expert": expert, "labels": decision.labels}, + ) + ) + + for page in profile.pages: + if _needs_ocr_page_task(page, config): + tasks.append( + _task( + tasks, + "ocr_page", + parsed.doc_id, + [page.page_num], + priority=80, + runtime=runtime, + config=config, + budget=budget, + image_path=_rendered_page_path(parsed, page.page_num), + reason="Page profile indicates scanned or weak digital text.", + inputs={"page_profile": page.to_dict()}, + ) + ) + + repair = config.get("repair", {}) + chunking = config.get("chunking", {}) + for table in parsed.tables: + crop_path = table.provenance.get("crop_path") + if crop_path and (repair.get("table_repair", True) or chunking.get("vision_guided", False)): + tasks.append( + _task( + tasks, + "table_vlm_repair", + parsed.doc_id, + table.page_nums, + priority=70, + runtime=runtime, + config=config, + budget=budget, + region_id=table.table_id, + image_path=crop_path, + bbox=table.bbox, + reason="Table crop is available for VLM verification or repair.", + inputs={ + "markdown": table.markdown, + "natural_language_rendering": table.natural_language_rendering, + }, + metadata={"source_parser": table.source_parser, "confidence": table.confidence}, + ) + ) + + for figure in parsed.figures: + if figure.image_path and (not figure.vlm_description or chunking.get("vision_guided", False)): + tasks.append( + _task( + tasks, + "figure_description", + parsed.doc_id, + [figure.page_num], + priority=60, + runtime=runtime, + config=config, + budget=budget, + region_id=figure.figure_id, + image_path=figure.image_path, + bbox=figure.bbox, + reason="Figure image is available for model description.", + inputs={"caption": figure.caption, "chart_data": figure.chart_data}, + metadata={"source_parser": figure.source_parser, "confidence": figure.confidence}, + ) + ) + + max_calls = int(config.get("gpu", {}).get("max_vlm_calls_per_doc", 30)) + planned = sorted(tasks, key=lambda item: (-item.priority, item.task_id))[:max_calls] + return [task.to_dict() for task in planned] + + +def make_gpu_task( + task_id: str, + task_type: str, + doc_id: str, + page_nums: list[int], + *, + priority: int, + runtime: dict[str, Any], + config: dict[str, Any], + budget: dict[str, Any], + reason: str, + region_id: str | None = None, + image_path: str | None = None, + bbox: Any = None, + inputs: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, +) -> dict[str, Any]: + model_role, model_config = _model_details(config, task_type) + task = GPUTask( + task_id=task_id, + task_type=task_type, + doc_id=doc_id, + page_nums=page_nums, + priority=priority, + region_id=region_id, + image_path=image_path, + bbox=bbox, + provider=str(runtime.get("provider", "huggingface_spaces")), + space_name=str(runtime.get("space_name", "zeroshotGPU")), + backend=str(runtime.get("backend", "transformers")), + gpu_models_target=str(runtime.get("gpu_models_target", "zeroshotGPU")), + model_role=model_role, + model_id=_model_id(model_config), + reason=reason, + budget=budget, + inputs=inputs or {}, + model_config=model_config, + metadata=metadata or {}, + ) + return task.to_dict() + + +def _task( + tasks: list[GPUTask], + task_type: str, + doc_id: str, + page_nums: list[int], + *, + priority: int, + runtime: dict[str, Any], + config: dict[str, Any], + budget: dict[str, Any], + reason: str, + region_id: str | None = None, + image_path: str | None = None, + bbox: Any = None, + inputs: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, +) -> GPUTask: + return GPUTask(**make_gpu_task( + f"gt{len(tasks) + 1}", + task_type, + doc_id, + page_nums, + priority=priority, + runtime=runtime, + config=config, + budget=budget, + reason=reason, + region_id=region_id, + image_path=image_path, + bbox=bbox, + inputs=inputs, + metadata=metadata, + )) + + +def _budget(config: dict[str, Any]) -> dict[str, Any]: + gpu = config.get("gpu", {}) + repair = config.get("repair", {}) + return { + "max_gpu_seconds_per_doc": float(gpu.get("max_gpu_seconds_per_doc", 120)), + "max_vlm_calls_per_doc": int(gpu.get("max_vlm_calls_per_doc", 30)), + "max_repair_iterations": int(repair.get("max_iterations", 3)), + "max_batch_size": int(gpu.get("max_batch_size", 4)), + } + + +def _model_details(config: dict[str, Any], task_type: str) -> tuple[str, dict[str, Any]]: + gpu = config.get("gpu", {}) + task_roles = gpu.get("task_model_roles", {}) + model_role = _default_model_role(task_type) + if isinstance(task_roles, dict): + model_role = str(task_roles.get(task_type, model_role)) + + models = gpu.get("models", {}) + if not isinstance(models, dict): + return model_role, {} + model_config = models.get(model_role, {}) + if isinstance(model_config, dict): + return model_role, dict(model_config) + if model_config: + return model_role, {"model_id": str(model_config)} + return model_role, {} + + +def _default_model_role(task_type: str) -> str: + if task_type == "ocr_page": + return "ocr" + if task_type == "table_vlm_repair": + return "table" + return "vlm" + + +def _model_id(model_config: dict[str, Any]) -> str | None: + model_id = model_config.get("model_id") + return str(model_id) if model_id else None + + +def _needs_ocr_page_task(page: Any, config: dict[str, Any]) -> bool: + routing = config.get("routing", {}) + scanned_threshold = float(routing.get("scanned_text_threshold", 0.40)) + return page.scanned_score > scanned_threshold or ( + page.digital_text_chars == 0 and page.digital_text_quality < 0.25 + ) + + +def _rendered_page_path(parsed: ParsedDocument, page_num: int) -> str | None: + for page in parsed.pages: + if int(page.get("page_num", 1)) != page_num: + continue + for parser_page in page.get("parser_pages", []): + rendered = parser_page.get("rendered_page") + if isinstance(rendered, dict) and rendered.get("image_path"): + return str(rendered["image_path"]) + rendered = page.get("rendered_page") + if isinstance(rendered, dict) and rendered.get("image_path"): + return str(rendered["image_path"]) + return None diff --git a/zsgdp/gpu/transformers_client.py b/zsgdp/gpu/transformers_client.py new file mode 100644 index 0000000000000000000000000000000000000000..a68c1b344c59ec4c9102938b606c507555211a7d --- /dev/null +++ b/zsgdp/gpu/transformers_client.py @@ -0,0 +1,75 @@ +"""Transformers backend client.""" + +from __future__ import annotations + +from functools import cached_property +from typing import Any + +from zsgdp.gpu.worker_prompts import prompt_for_task + + +class TransformersClient: + def __init__(self, model_id: str | None, model_config: dict[str, Any] | None = None): + self.model_id = model_id + self.model_config = model_config or {} + + @property + def available(self) -> bool: + if not self.model_id: + return False + try: + import transformers # noqa: F401 + except Exception: + return False + return True + + def execute_task(self, task: dict[str, Any]) -> dict[str, Any]: + if not self.available: + return { + "status": "backend_unavailable", + "error": "Transformers is not installed or model_id is missing.", + } + + prompt = prompt_for_task(task) + image_path = task.get("image_path") + try: + if image_path: + output = self._pipeline({"image": str(image_path), "text": prompt}) + else: + output = self._pipeline(prompt) + except Exception as exc: + return {"status": "execution_failed", "error": str(exc)} + + return { + "status": "executed", + "text": _extract_text(output), + "raw_output": output, + } + + @cached_property + def _pipeline(self): + from transformers import pipeline # type: ignore + + task = str(self.model_config.get("task", "image-text-to-text")) + kwargs: dict[str, Any] = {"model": self.model_id} + dtype = self.model_config.get("dtype") + device = self.model_config.get("device") + if dtype: + kwargs["torch_dtype"] = dtype + if device and device != "auto": + kwargs["device"] = device + elif device == "auto": + kwargs["device_map"] = "auto" + return pipeline(task, **kwargs) + + +def _extract_text(output: Any) -> str: + if isinstance(output, str): + return output + if isinstance(output, list) and output: + return _extract_text(output[0]) + if isinstance(output, dict): + for key in ("generated_text", "text", "summary_text", "answer"): + if output.get(key): + return str(output[key]) + return str(output) diff --git a/zsgdp/gpu/vllm_client.py b/zsgdp/gpu/vllm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..627a598ccea1bec9f565e70dcc3e0d064e373581 --- /dev/null +++ b/zsgdp/gpu/vllm_client.py @@ -0,0 +1,77 @@ +"""vLLM/OpenAI-compatible backend client.""" + +from __future__ import annotations + +import base64 +import json +from pathlib import Path +from typing import Any +from urllib import request +from urllib.error import URLError + +from zsgdp.gpu.worker_prompts import prompt_for_task + + +class VLLMClient: + def __init__(self, endpoint: str | None, model_id: str | None = None, api_key: str | None = None): + self.endpoint = endpoint.rstrip("/") if endpoint else None + self.model_id = model_id or "default" + self.api_key = api_key + + @property + def available(self) -> bool: + return bool(self.endpoint) + + def execute_task(self, task: dict[str, Any]) -> dict[str, Any]: + if not self.available: + return {"status": "backend_unavailable", "error": "vLLM endpoint is not configured."} + + payload = { + "model": self.model_id, + "messages": [{"role": "user", "content": self._message_content(task)}], + "temperature": 0, + } + data = json.dumps(payload).encode("utf-8") + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + req = request.Request(f"{self.endpoint}/v1/chat/completions", data=data, headers=headers, method="POST") + try: + with request.urlopen(req, timeout=120) as response: + body = json.loads(response.read().decode("utf-8")) + except (OSError, URLError, json.JSONDecodeError) as exc: + return {"status": "execution_failed", "error": str(exc)} + + return { + "status": "executed", + "text": _extract_chat_text(body), + "raw_output": body, + } + + def _message_content(self, task: dict[str, Any]) -> Any: + prompt = prompt_for_task(task) + image_path = task.get("image_path") + if not image_path: + return prompt + data_url = _image_data_url(Path(str(image_path))) + return [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": data_url}}, + ] + + +def _image_data_url(path: Path) -> str: + suffix = path.suffix.lower().lstrip(".") or "png" + mime = "jpeg" if suffix in {"jpg", "jpeg"} else suffix + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + return f"data:image/{mime};base64,{encoded}" + + +def _extract_chat_text(body: dict[str, Any]) -> str: + choices = body.get("choices") or [] + if choices: + message = choices[0].get("message", {}) + content = message.get("content") + if content: + return str(content) + return json.dumps(body, ensure_ascii=False) diff --git a/zsgdp/gpu/worker.py b/zsgdp/gpu/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..019a7ce1385d17ded020b9bc373f8ca2f540f951 --- /dev/null +++ b/zsgdp/gpu/worker.py @@ -0,0 +1,143 @@ +"""GPU worker extension point.""" + +from __future__ import annotations + +from pathlib import Path +from time import perf_counter + +from zsgdp.gpu.transformers_client import TransformersClient +from zsgdp.gpu.vllm_client import VLLMClient +from zsgdp.logging_config import get_logger + +logger = get_logger(__name__) + + +class GPUWorker: + def __init__(self, config: dict): + self.config = config + + def run(self, task: dict, *, dry_run: bool = True) -> dict: + readiness = _task_readiness(task) + base = { + "dry_run": dry_run, + "task_id": task.get("task_id"), + "task_type": task.get("task_type"), + "region_id": task.get("region_id"), + "provider": task.get("provider") or self.config.get("gpu", {}).get("provider", "huggingface_spaces"), + "space_name": task.get("space_name") or self.config.get("gpu", {}).get("space_name", "zeroshotGPU"), + "backend": task.get("backend") or self.config.get("gpu", {}).get("backend", "transformers"), + "model_role": task.get("model_role"), + "model_id": task.get("model_id"), + "readiness": readiness, + } + if dry_run: + return {**base, "status": "ready_dry_run" if readiness["ready"] else "blocked_missing_inputs"} + if not readiness["ready"]: + logger.warning( + "gpu_task_blocked", + extra={ + "task_id": task.get("task_id"), + "task_type": task.get("task_type"), + "missing_inputs": readiness.get("missing_inputs"), + }, + ) + return {**base, "status": "blocked_missing_inputs"} + + started = perf_counter() + client = self._client_for_task(task) + result = client.execute_task(task) + status = result.get("status", "execution_failed") + elapsed = round(perf_counter() - started, 3) + log_method = logger.info if status == "executed" else logger.warning + log_method( + "gpu_task_executed", + extra={ + "task_id": task.get("task_id"), + "task_type": task.get("task_type"), + "model_id": task.get("model_id"), + "backend": base["backend"], + "status": status, + "elapsed_seconds": elapsed, + }, + ) + return { + **base, + "status": status, + "output": result, + } + + def run_batch(self, batch: dict, *, dry_run: bool = True) -> dict: + task_results = [self.run(task, dry_run=dry_run) for task in batch.get("tasks", [])] + ready_count = sum(1 for result in task_results if result["readiness"]["ready"]) + executed_count = sum(1 for result in task_results if result["status"] == "executed") + failed_count = sum( + 1 + for result in task_results + if result["readiness"]["ready"] and result["status"] not in {"ready_dry_run", "executed"} + ) + return { + "batch_id": batch.get("batch_id"), + "task_type": batch.get("task_type"), + "provider": batch.get("provider"), + "space_name": batch.get("space_name"), + "backend": batch.get("backend"), + "model_role": batch.get("model_role"), + "model_id": batch.get("model_id"), + "task_count": len(task_results), + "ready_count": ready_count, + "blocked_count": len(task_results) - ready_count, + "executed_count": executed_count, + "failed_count": failed_count, + "status": "dry_run_complete" if dry_run else _batch_status(task_results), + "results": task_results, + } + + def _client_for_task(self, task: dict): + backend = str(task.get("backend") or self.config.get("gpu", {}).get("backend", "transformers")) + model_config = _model_config_for_task(self.config, task) + model_id = task.get("model_id") or model_config.get("model_id") + if backend == "vllm": + gpu = self.config.get("gpu", {}) + endpoint = model_config.get("endpoint") or gpu.get("vllm_endpoint") or gpu.get("endpoint") + return VLLMClient(endpoint=endpoint, model_id=model_id, api_key=gpu.get("api_key")) + return TransformersClient(model_id=model_id, model_config=model_config) + + +def _task_readiness(task: dict) -> dict: + missing: list[str] = [] + image_path = task.get("image_path") + if task.get("task_type") in {"ocr_page", "table_vlm_repair", "figure_description", "vlm_route_repair"}: + if image_path and not Path(str(image_path)).exists(): + missing.append("image_path") + if not image_path and task.get("task_type") != "vlm_route_repair": + missing.append("image_path") + if not task.get("doc_id"): + missing.append("doc_id") + if not task.get("page_nums"): + missing.append("page_nums") + return { + "ready": not missing, + "missing_inputs": missing, + "image_path_exists": bool(image_path and Path(str(image_path)).exists()), + } + + +def _model_config_for_task(config: dict, task: dict) -> dict: + if isinstance(task.get("model_config"), dict) and task["model_config"]: + return dict(task["model_config"]) + role = task.get("model_role") or "vlm" + models = config.get("gpu", {}).get("models", {}) + model_config = models.get(role, {}) if isinstance(models, dict) else {} + return dict(model_config) if isinstance(model_config, dict) else {} + + +def _batch_status(task_results: list[dict]) -> str: + if not task_results: + return "execute_complete" + if all(result["status"] == "executed" for result in task_results): + return "execute_complete" + if any(result["status"] == "executed" for result in task_results): + return "execute_partial" + if all(not result["readiness"]["ready"] for result in task_results): + return "blocked_missing_inputs" + return "execute_failed" diff --git a/zsgdp/gpu/worker_prompts.py b/zsgdp/gpu/worker_prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..192f17804d604d21f1551d16e8a9cbf04f22fb51 --- /dev/null +++ b/zsgdp/gpu/worker_prompts.py @@ -0,0 +1,27 @@ +"""Prompts for GPU task execution.""" + +from __future__ import annotations + +from typing import Any + + +def prompt_for_task(task: dict[str, Any]) -> str: + task_type = str(task.get("task_type", "unknown")) + if task_type == "ocr_page": + return "Transcribe the page faithfully. Preserve layout cues and mark uncertain text." + if task_type == "table_vlm_repair": + markdown = task.get("inputs", {}).get("markdown") or "" + return ( + "Verify and repair this table from the supplied image. Return Markdown and uncertainty notes. " + f"Candidate Markdown:\n{markdown}" + ) + if task_type == "figure_description": + caption = task.get("inputs", {}).get("caption") or "" + return ( + "Describe this figure for document understanding and downstream chunking. " + f"Existing caption: {caption}" + ) + if task_type == "vlm_route_repair": + expert = task.get("inputs", {}).get("expert", "vlm") + return f"Inspect this routed page or region for parser repair. Router expert: {expert}." + return "Run the requested document-understanding task and return structured Markdown-friendly text." diff --git a/zsgdp/logging_config.py b/zsgdp/logging_config.py new file mode 100644 index 0000000000000000000000000000000000000000..73c7ed7cea3597ddfca00cfc0cebb11bdfe5e9f0 --- /dev/null +++ b/zsgdp/logging_config.py @@ -0,0 +1,150 @@ +"""Logging configuration for zsgdp. + +Design choices (pinned): + +- Default log level is WARNING. CLI summary lines stay on stdout so users + parsing CLI output are unaffected. INFO breadcrumbs (parse start, repair + iteration, GPU task planned, etc.) go to stderr only when the operator + opts in via ZSGDP_LOG_LEVEL. +- All modules use `logging.getLogger(__name__)` so loggers are named + hierarchically (zsgdp.pipeline, zsgdp.repair.controller, etc.) and a + single root configuration governs them all. +- `configure_logging()` is idempotent: calling it twice does not duplicate + handlers. CLI entrypoints and the Gradio app both call it on startup. +- JSON formatting is opt-in via ZSGDP_LOG_JSON=1; HF Spaces logs render + structured records better in that form. Default is the human-readable + format for local terminal use. + +Environment variables: + +- ZSGDP_LOG_LEVEL (DEBUG|INFO|WARNING|ERROR|CRITICAL) — default WARNING. +- ZSGDP_LOG_JSON (1|true) — emit one-line JSON records on stderr. +""" + +from __future__ import annotations + +import json +import logging +import os +import sys +from typing import Any + +_CONFIGURED = False +_HANDLER: logging.Handler | None = None + + +def configure_logging( + *, + level: str | int | None = None, + json_format: bool | None = None, + stream: Any = None, +) -> logging.Logger: + """Configure the root zsgdp logger. Idempotent; safe to call multiple times. + + Explicit arguments override env vars; env vars override defaults. + """ + + global _CONFIGURED, _HANDLER + + resolved_level = _resolve_level(level) + resolved_json = _resolve_json(json_format) + resolved_stream = stream if stream is not None else sys.stderr + + root = logging.getLogger("zsgdp") + root.setLevel(resolved_level) + root.propagate = False + + if _HANDLER is not None: + root.removeHandler(_HANDLER) + + handler = logging.StreamHandler(resolved_stream) + handler.setLevel(resolved_level) + handler.setFormatter(_JsonFormatter() if resolved_json else _TextFormatter()) + root.addHandler(handler) + _HANDLER = handler + _CONFIGURED = True + return root + + +def get_logger(name: str | None = None) -> logging.Logger: + """Return a logger under the zsgdp namespace. + + Pass __name__ from caller modules so the logger inherits the + `zsgdp.` naming. + """ + + if name and name.startswith("zsgdp"): + return logging.getLogger(name) + return logging.getLogger(f"zsgdp.{name}" if name else "zsgdp") + + +def _resolve_level(level: str | int | None) -> int: + if level is not None: + if isinstance(level, int): + return level + return logging.getLevelName(str(level).upper()) + env_level = os.environ.get("ZSGDP_LOG_LEVEL") + if env_level: + resolved = logging.getLevelName(env_level.upper()) + if isinstance(resolved, int): + return resolved + return logging.WARNING + + +def _resolve_json(json_format: bool | None) -> bool: + if json_format is not None: + return bool(json_format) + return os.environ.get("ZSGDP_LOG_JSON", "").strip().lower() in {"1", "true", "yes"} + + +class _TextFormatter(logging.Formatter): + def __init__(self) -> None: + super().__init__("%(asctime)s %(levelname)s %(name)s %(message)s") + + +class _JsonFormatter(logging.Formatter): + """One-line JSON records keyed for HF Spaces / Datadog / Loki ingestion.""" + + _STD_KEYS = { + "name", + "msg", + "args", + "levelname", + "levelno", + "pathname", + "filename", + "module", + "exc_info", + "exc_text", + "stack_info", + "lineno", + "funcName", + "created", + "msecs", + "relativeCreated", + "thread", + "threadName", + "processName", + "process", + "asctime", + "message", + } + + def format(self, record: logging.LogRecord) -> str: # type: ignore[override] + payload: dict[str, Any] = { + "ts": self.formatTime(record, "%Y-%m-%dT%H:%M:%S%z"), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + for key, value in record.__dict__.items(): + if key in self._STD_KEYS or key.startswith("_"): + continue + try: + json.dumps(value) + except TypeError: + value = repr(value) + payload[key] = value + if record.exc_info: + payload["exc"] = self.formatException(record.exc_info) + return json.dumps(payload, ensure_ascii=False) diff --git a/zsgdp/merge/__init__.py b/zsgdp/merge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13da5ed4cb6153b1b045b8cf389ad43e50839ac5 --- /dev/null +++ b/zsgdp/merge/__init__.py @@ -0,0 +1,3 @@ +from zsgdp.merge.merge_candidates import merge_candidates + +__all__ = ["merge_candidates"] diff --git a/zsgdp/merge/align.py b/zsgdp/merge/align.py new file mode 100644 index 0000000000000000000000000000000000000000..89ef5c9e455fc91e5fa442fce5c360f3074e16d4 --- /dev/null +++ b/zsgdp/merge/align.py @@ -0,0 +1,41 @@ +"""Candidate alignment helpers.""" + +from __future__ import annotations + +import re + +from zsgdp.schema import Element +from zsgdp.utils import normalize_whitespace + + +def element_alignment_key(element: Element) -> tuple[int, str, str]: + content = normalize_element_content(element) + return (element.page_num, element_alignment_group(element), content[:240]) + + +def normalize_element_content(element: Element) -> str: + content = element.content() + content = re.sub(r"^#{1,6}\s+", "", content.strip()) + content = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"\1 \2", content) + content = content.replace("|", " ") + content = re.sub(r"[-:]{3,}", " ", content) + return normalize_whitespace(content).lower() + + +def element_alignment_group(element: Element) -> str: + if element.type in {"title", "heading", "paragraph", "caption", "list_item", "header", "footer"}: + return "text" + return element.type + + +def bbox_overlap_ratio(a: tuple[float, float, float, float], b: tuple[float, float, float, float]) -> float: + ax0, ay0, ax1, ay1 = a + bx0, by0, bx1, by1 = b + ix0, iy0 = max(ax0, bx0), max(ay0, by0) + ix1, iy1 = min(ax1, bx1), min(ay1, by1) + if ix1 <= ix0 or iy1 <= iy0: + return 0.0 + intersection = (ix1 - ix0) * (iy1 - iy0) + area_a = max((ax1 - ax0) * (ay1 - ay0), 1.0) + area_b = max((bx1 - bx0) * (by1 - by0), 1.0) + return intersection / min(area_a, area_b) diff --git a/zsgdp/merge/conflict_detection.py b/zsgdp/merge/conflict_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..71dfe773035cc48f7eb47c00a07c1e17c2070ae4 --- /dev/null +++ b/zsgdp/merge/conflict_detection.py @@ -0,0 +1,256 @@ +"""Lightweight parser conflict detection.""" + +from __future__ import annotations + +from itertools import combinations +import re +from statistics import mean +from typing import Any + +from zsgdp.schema import Element, ParseCandidate, QualityIssue, TableObject + +TEXT_GAP_THRESHOLD = 0.55 +COUNT_GAP_THRESHOLD = 0.50 +READING_ORDER_LCS_THRESHOLD = 0.65 + + +def detect_candidate_conflicts(candidates: list[ParseCandidate]) -> list[QualityIssue]: + report = build_candidate_conflict_report(candidates) + issues: list[QualityIssue] = [] + for conflict in report["conflicts"]: + issues.append( + QualityIssue( + issue_type="parser_disagreement", + severity=conflict.get("severity", "warning"), + message=conflict["message"], + page_num=conflict.get("page_num"), + region_id=conflict.get("region_id"), + blocking=False, + metadata={key: value for key, value in conflict.items() if key not in {"message", "severity"}}, + ) + ) + return issues + + +def build_candidate_conflict_report(candidates: list[ParseCandidate]) -> dict[str, Any]: + summaries = [_candidate_summary(candidate) for candidate in candidates] + conflicts: list[dict[str, Any]] = [] + if len(candidates) >= 2: + conflicts.extend(_text_coverage_conflicts(summaries)) + conflicts.extend(_count_conflicts(summaries, "table_count", "table objects")) + conflicts.extend(_count_conflicts(summaries, "figure_count", "figure objects")) + conflicts.extend(_reading_order_conflicts(candidates)) + conflicts.extend(_table_structure_conflicts(candidates)) + + return { + "schema_version": 1, + "candidate_count": len(candidates), + "parser_summaries": summaries, + "conflict_count": len(conflicts), + "conflicts": conflicts, + } + + +def _candidate_summary(candidate: ParseCandidate) -> dict[str, Any]: + text_lengths = [len(element.content()) for element in candidate.elements] + return { + "parser": candidate.parser_name, + "confidence": candidate.confidence, + "page_count": len(candidate.pages), + "element_count": len(candidate.elements), + "table_count": len(candidate.tables), + "figure_count": len(candidate.figures), + "text_chars": sum(text_lengths), + "mean_element_chars": mean(text_lengths) if text_lengths else 0.0, + } + + +def _text_coverage_conflicts(summaries: list[dict[str, Any]]) -> list[dict[str, Any]]: + text_lengths = {summary["parser"]: int(summary["text_chars"]) for summary in summaries} + if not text_lengths: + return [] + max_parser = max(text_lengths, key=text_lengths.get) + min_parser = min(text_lengths, key=text_lengths.get) + max_len = max(text_lengths[max_parser], 1) + min_len = text_lengths[min_parser] + ratio = min_len / max_len + if ratio >= TEXT_GAP_THRESHOLD: + return [] + return [ + { + "conflict_id": "pc1", + "type": "text_coverage_gap", + "severity": "warning", + "message": ( + f"Parser text coverage differs strongly: {min_parser}={min_len} chars, " + f"{max_parser}={max_len} chars." + ), + "parsers": [min_parser, max_parser], + "ratio": ratio, + "text_lengths": text_lengths, + } + ] + + +def _count_conflicts(summaries: list[dict[str, Any]], key: str, label: str) -> list[dict[str, Any]]: + counts = {summary["parser"]: int(summary[key]) for summary in summaries} + max_parser = max(counts, key=counts.get) + min_parser = min(counts, key=counts.get) + max_count = counts[max_parser] + min_count = counts[min_parser] + if max_count == 0: + return [] + ratio = min_count / max_count + if ratio >= COUNT_GAP_THRESHOLD: + return [] + return [ + { + "conflict_id": f"pc_{key}", + "type": f"{key}_disagreement", + "severity": "warning", + "message": ( + f"Parser {label} counts differ strongly: " + f"{min_parser}={min_count}, {max_parser}={max_count}." + ), + "parsers": [min_parser, max_parser], + "ratio": ratio, + "counts": counts, + } + ] + + +def _reading_order_conflicts(candidates: list[ParseCandidate]) -> list[dict[str, Any]]: + conflicts: list[dict[str, Any]] = [] + for left, right in combinations(candidates, 2): + left_pages = _element_sequences_by_page(left.elements) + right_pages = _element_sequences_by_page(right.elements) + for page_num in sorted(set(left_pages) & set(right_pages)): + left_sequence = left_pages[page_num] + right_sequence = right_pages[page_num] + common = set(left_sequence) & set(right_sequence) + if len(common) < 3: + continue + left_common = [key for key in left_sequence if key in common] + right_common = [key for key in right_sequence if key in common] + order_similarity = _lcs_length(left_common, right_common) / len(common) + if order_similarity >= READING_ORDER_LCS_THRESHOLD: + continue + conflicts.append( + { + "conflict_id": f"pc_reading_{len(conflicts) + 1}", + "type": "reading_order_disagreement", + "severity": "warning", + "message": ( + f"Parsers disagree on reading order for page {page_num}: " + f"{left.parser_name} vs {right.parser_name}." + ), + "page_num": page_num, + "parsers": [left.parser_name, right.parser_name], + "order_similarity": order_similarity, + "common_element_count": len(common), + } + ) + return conflicts + + +def _table_structure_conflicts(candidates: list[ParseCandidate]) -> list[dict[str, Any]]: + conflicts: list[dict[str, Any]] = [] + for left, right in combinations(candidates, 2): + for left_table in left.tables: + for right_table in right.tables: + if not set(left_table.page_nums or []) & set(right_table.page_nums or []): + continue + overlap = _table_token_overlap(left_table, right_table) + if overlap < 0.50: + continue + left_shape = _table_shape(left_table) + right_shape = _table_shape(right_table) + if left_shape == right_shape: + continue + page_num = min(set(left_table.page_nums or [1]) & set(right_table.page_nums or [1]) or {1}) + conflicts.append( + { + "conflict_id": f"pc_table_{len(conflicts) + 1}", + "type": "table_structure_disagreement", + "severity": "warning", + "message": ( + f"Parsers disagree on table structure on page {page_num}: " + f"{left_table.table_id} shape={left_shape}, {right_table.table_id} shape={right_shape}." + ), + "page_num": page_num, + "region_id": f"{left_table.table_id}|{right_table.table_id}", + "parsers": [left.parser_name, right.parser_name], + "table_ids": [left_table.table_id, right_table.table_id], + "left_shape": {"rows": left_shape[0], "columns": left_shape[1]}, + "right_shape": {"rows": right_shape[0], "columns": right_shape[1]}, + "token_overlap": overlap, + } + ) + return conflicts + + +def _element_sequences_by_page(elements: list[Element]) -> dict[int, list[str]]: + pages: dict[int, list[Element]] = {} + for element in elements: + pages.setdefault(element.page_num, []).append(element) + return { + page_num: [ + key + for key in (_content_key(element.content()) for element in sorted(page_elements, key=lambda item: item.reading_order or 0)) + if key + ] + for page_num, page_elements in pages.items() + } + + +def _content_key(text: str) -> str: + tokens = re.findall(r"[a-z0-9]+", text.lower()) + if not tokens: + return "" + return " ".join(tokens[:8]) + + +def _lcs_length(left: list[str], right: list[str]) -> int: + previous = [0] * (len(right) + 1) + for left_item in left: + current = [0] + for index, right_item in enumerate(right, start=1): + if left_item == right_item: + current.append(previous[index - 1] + 1) + else: + current.append(max(previous[index], current[-1])) + previous = current + return previous[-1] + + +def _table_shape(table: TableObject) -> tuple[int, int]: + rows = _table_rows(table.markdown or "") + if not rows: + return (0, 0) + return (len(rows), max(len(row) for row in rows)) + + +def _table_rows(markdown: str) -> list[list[str]]: + rows: list[list[str]] = [] + for line in markdown.splitlines(): + stripped = line.strip() + if not (stripped.startswith("|") and stripped.endswith("|")): + continue + cells = [cell.strip() for cell in stripped.strip("|").split("|")] + if cells and all(re.fullmatch(r":?-{3,}:?", cell) for cell in cells if cell): + continue + rows.append(cells) + return rows + + +def _table_token_overlap(left: TableObject, right: TableObject) -> float: + left_tokens = _table_tokens(left) + right_tokens = _table_tokens(right) + if min(len(left_tokens), len(right_tokens)) < 2: + return 0.0 + return len(left_tokens & right_tokens) / min(len(left_tokens), len(right_tokens)) + + +def _table_tokens(table: TableObject) -> set[str]: + source = table.markdown or table.html or table.natural_language_rendering or table.caption or "" + return set(re.findall(r"[a-z0-9]+", source.lower())) diff --git a/zsgdp/merge/dedupe.py b/zsgdp/merge/dedupe.py new file mode 100644 index 0000000000000000000000000000000000000000..96054741ab2e6231c55081472a0f2405a99b98d2 --- /dev/null +++ b/zsgdp/merge/dedupe.py @@ -0,0 +1,226 @@ +"""Deduplication helpers.""" + +from __future__ import annotations + +import re + +from zsgdp.merge.align import element_alignment_key, normalize_element_content +from zsgdp.schema import Element, FigureObject, TableObject + +STRUCTURE_RANK = { + "title": 5, + "heading": 4, + "caption": 3, + "list_item": 2, + "paragraph": 1, + "header": 0, + "footer": 0, +} + + +def dedupe_elements(elements: list[Element]) -> list[Element]: + winners: dict[tuple[int, str, str], Element] = {} + for element in elements: + key = element_alignment_key(element) + current = winners.get(key) + if current is None: + winners[key] = element + continue + winner = _choose_element(current, element) + loser = element if winner is current else current + _merge_element_metadata(winner, loser) + winners[key] = winner + return sorted( + _drop_table_text_duplicates(list(winners.values())), + key=lambda item: (item.page_num, item.reading_order or 0, item.element_id), + ) + + +def dedupe_tables(tables: list[TableObject]) -> list[TableObject]: + winners: list[TableObject] = [] + for table in tables: + current_index = next((index for index, current in enumerate(winners) if _same_table(table, current)), None) + if current_index is None: + winners.append(table) + continue + + current = winners[current_index] + winner = _choose_table(current, table) + loser = table if winner is current else current + _merge_table_metadata(winner, loser) + winners[current_index] = winner + return sorted(winners, key=lambda item: (min(item.page_nums or [0]), item.table_id)) + + +def dedupe_figures(figures: list[FigureObject]) -> list[FigureObject]: + winners: dict[tuple[int, str], FigureObject] = {} + for figure in figures: + bbox_key = ",".join(f"{value:.1f}" for value in figure.bbox) if figure.bbox else figure.figure_id + key = (figure.page_num, bbox_key) + current = winners.get(key) + if current is None or figure.confidence > current.confidence: + winners[key] = figure + return sorted(winners.values(), key=lambda item: (item.page_num, item.figure_id)) + + +def _choose_element(a: Element, b: Element) -> Element: + if a.type == "table" and b.type == "table": + return b if _table_element_score(b) > _table_element_score(a) else a + a_score = _element_score(a) + b_score = _element_score(b) + return b if b_score > a_score else a + + +def _element_score(element: Element) -> tuple[float, int, int, int]: + confidence = element.confidence or 0.0 + structure = STRUCTURE_RANK.get(element.type, 1) + has_markdown = 1 if element.markdown else 0 + has_bbox = 1 if element.bbox else 0 + return (confidence, structure, has_markdown, has_bbox) + + +def _table_element_score(element: Element) -> tuple[int, int, int, float]: + rows = _table_rows(element.content()) + data_rows = max(0, len(rows) - 1) + cell_count = sum(1 for row in rows for cell in row if cell) + has_bbox = 1 if element.bbox else 0 + return (data_rows, cell_count, has_bbox, element.confidence or 0.0) + + +def _merge_element_metadata(winner: Element, loser: Element) -> None: + if winner.bbox is None and loser.bbox is not None: + winner.bbox = loser.bbox + winner.provenance["bbox_source_parser"] = loser.source_parser + if winner.html is None and loser.html is not None: + winner.html = loser.html + if winner.markdown is None and loser.markdown is not None and winner.type in {"table", "heading", "title", "list_item"}: + winner.markdown = loser.markdown + if winner.text is None and loser.text is not None and winner.type != "table": + winner.text = loser.text + if winner.reading_order is None and loser.reading_order is not None: + winner.reading_order = loser.reading_order + if winner.type == "table" and loser.confidence is not None: + winner.confidence = max(winner.confidence or 0.0, loser.confidence) + merged_from = winner.provenance.setdefault("merged_from", []) + if isinstance(merged_from, list): + merged_from.append( + { + "element_id": loser.element_id, + "source_parser": loser.source_parser, + "type": loser.type, + "confidence": loser.confidence, + } + ) + + +def _same_table(a: TableObject, b: TableObject) -> bool: + if not set(a.page_nums or []) & set(b.page_nums or []): + return False + + a_text = _table_text(a) + b_text = _table_text(b) + if a_text and b_text and a_text == b_text: + return True + + a_tokens = _table_tokens(a) + b_tokens = _table_tokens(b) + if min(len(a_tokens), len(b_tokens)) < 4: + return False + overlap = len(a_tokens & b_tokens) / min(len(a_tokens), len(b_tokens)) + return overlap >= 0.8 + + +def _choose_table(a: TableObject, b: TableObject) -> TableObject: + return b if _table_score(b) > _table_score(a) else a + + +def _table_score(table: TableObject) -> tuple[int, int, int, float]: + rows = _table_rows(table.markdown or "") + data_rows = max(0, len(rows) - 1) + cell_count = sum(1 for row in rows for cell in row if cell) + has_bbox = 1 if table.bbox else 0 + return (data_rows, cell_count, has_bbox, table.confidence) + + +def _merge_table_metadata(winner: TableObject, loser: TableObject) -> None: + if winner.bbox is None and loser.bbox is not None: + winner.bbox = loser.bbox + winner.provenance["bbox_source_parser"] = loser.source_parser + if winner.html is None and loser.html is not None: + winner.html = loser.html + if winner.markdown is None and loser.markdown is not None: + winner.markdown = loser.markdown + if winner.dataframe_json is None and loser.dataframe_json is not None: + winner.dataframe_json = loser.dataframe_json + if winner.natural_language_rendering is None and loser.natural_language_rendering is not None: + winner.natural_language_rendering = loser.natural_language_rendering + if winner.caption is None and loser.caption is not None: + winner.caption = loser.caption + + crop_path = loser.provenance.get("crop_path") + if crop_path and "crop_path" not in winner.provenance: + winner.provenance["crop_path"] = crop_path + + winner.confidence = max(winner.confidence, loser.confidence) + source_parsers = winner.provenance.setdefault("source_parsers", [winner.source_parser]) + if isinstance(source_parsers, list) and loser.source_parser not in source_parsers: + source_parsers.append(loser.source_parser) + + merged_from = winner.provenance.setdefault("merged_from", []) + if isinstance(merged_from, list): + merged_from.append( + { + "table_id": loser.table_id, + "source_parser": loser.source_parser, + "confidence": loser.confidence, + } + ) + + +def _table_text(table: TableObject) -> str: + source = table.markdown or table.html or table.natural_language_rendering or table.caption or "" + return " ".join(source.lower().split()) + + +def _table_tokens(table: TableObject) -> set[str]: + return set(re.findall(r"[a-z0-9]+", _table_text(table))) + + +def _table_rows(markdown: str) -> list[list[str]]: + rows: list[list[str]] = [] + for line in markdown.splitlines(): + stripped = line.strip() + if not (stripped.startswith("|") and stripped.endswith("|")): + continue + cells = [cell.strip() for cell in stripped.strip("|").split("|")] + if cells and all(re.fullmatch(r":?-{3,}:?", cell) for cell in cells if cell): + continue + rows.append(cells) + return rows + + +def _drop_table_text_duplicates(elements: list[Element]) -> list[Element]: + table_text_by_page: dict[int, list[str]] = {} + for element in elements: + if element.type == "table": + text = normalize_element_content(element) + if text: + table_text_by_page.setdefault(element.page_num, []).append(text) + + filtered: list[Element] = [] + for element in elements: + if element.type == "paragraph": + text = normalize_element_content(element) + if any(_same_table_text(text, table_text) for table_text in table_text_by_page.get(element.page_num, [])): + continue + filtered.append(element) + return filtered + + +def _same_table_text(text: str, table_text: str) -> bool: + if not text or not table_text: + return False + if text == table_text: + return True + shorter, longer = sorted([text, table_text], key=len) + return len(shorter) >= 20 and shorter in longer diff --git a/zsgdp/merge/merge_candidates.py b/zsgdp/merge/merge_candidates.py new file mode 100644 index 0000000000000000000000000000000000000000..56dd736d78cad4f038595ae483b205d70e743939 --- /dev/null +++ b/zsgdp/merge/merge_candidates.py @@ -0,0 +1,46 @@ +"""Merge parser candidates into a canonical ParsedDocument.""" + +from __future__ import annotations + +from zsgdp.merge.conflict_detection import build_candidate_conflict_report, detect_candidate_conflicts +from zsgdp.merge.dedupe import dedupe_elements, dedupe_figures, dedupe_tables +from zsgdp.schema import DocumentProfile, ParseCandidate, ParsedDocument, QualityReport + + +def merge_candidates(candidates: list[ParseCandidate], profile: DocumentProfile) -> ParsedDocument: + pages_by_num: dict[int, dict] = { + page.page_num: {"page_num": page.page_num, "profile": page.to_dict()} for page in profile.pages + } + elements = [] + tables = [] + figures = [] + for candidate in candidates: + for page in candidate.pages: + page_num = int(page.get("page_num", 1)) + pages_by_num.setdefault(page_num, {"page_num": page_num}) + pages_by_num[page_num].setdefault("parser_pages", []).append(page) + elements.extend(candidate.elements) + tables.extend(candidate.tables) + figures.extend(candidate.figures) + + conflict_report = build_candidate_conflict_report(candidates) + quality = QualityReport(score=1.0) + quality.issues.extend(detect_candidate_conflicts(candidates)) + + return ParsedDocument( + doc_id=profile.doc_id, + source_path=profile.source_path, + file_type=profile.file_type, + pages=[pages_by_num[key] for key in sorted(pages_by_num)], + elements=dedupe_elements(elements), + tables=dedupe_tables(tables), + figures=dedupe_figures(figures), + quality_report=quality, + provenance={ + "candidate_parsers": [candidate.parser_name for candidate in candidates], + "candidate_confidences": { + candidate.parser_name: candidate.confidence for candidate in candidates + }, + "conflict_report": conflict_report, + }, + ) diff --git a/zsgdp/normalize/__init__.py b/zsgdp/normalize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c85a40a90c051c4c3231501c4a831529a32da562 --- /dev/null +++ b/zsgdp/normalize/__init__.py @@ -0,0 +1,5 @@ +"""Parser normalization package.""" + +from zsgdp.normalize.markdown import markdown_to_blocks, normalize_markdown_candidate, normalize_markdown_table + +__all__ = ["markdown_to_blocks", "normalize_markdown_candidate", "normalize_markdown_table"] diff --git a/zsgdp/normalize/markdown.py b/zsgdp/normalize/markdown.py new file mode 100644 index 0000000000000000000000000000000000000000..120661bf1ba66ab852db32cf5f97b9cd43145a00 --- /dev/null +++ b/zsgdp/normalize/markdown.py @@ -0,0 +1,245 @@ +"""Markdown-to-schema normalization helpers.""" + +from __future__ import annotations + +from dataclasses import dataclass +import re + +from zsgdp.schema import Element, FigureObject, ParseCandidate, TableObject + +PAGE_COMMENT_RE = re.compile(r"", re.IGNORECASE) +PAGE_HEADING_RE = re.compile(r"^\s*(?:-{3,}\s*)?page\s+(\d+)(?:\s*-{3,})?\s*$", re.IGNORECASE) +IMAGE_RE = re.compile(r"!\[(?P[^\]]*)\]\((?P[^)]+)\)") +CAPTION_RE = re.compile(r"^(figure|fig\.|table|chart|exhibit)\s+\d*", re.IGNORECASE) +PLAIN_TABLE_SPLIT_RE = re.compile(r"\t+|\s{2,}") + + +@dataclass(slots=True) +class MarkdownBlock: + block_type: str + text: str + page_num: int + metadata: dict + + +def normalize_markdown_candidate( + *, + markdown: str, + doc_id: str, + source_path: str, + file_type: str, + parser_name: str, + confidence: float = 0.86, + provenance: dict | None = None, +) -> ParseCandidate: + """Normalize Markdown into the shared parse candidate schema.""" + + blocks = markdown_to_blocks(markdown) + elements: list[Element] = [] + tables: list[TableObject] = [] + figures: list[FigureObject] = [] + reading_counts: dict[int, int] = {} + table_count = 0 + figure_count = 0 + + for block in blocks: + reading_counts[block.page_num] = reading_counts.get(block.page_num, 0) + 1 + order = reading_counts[block.page_num] + element_id = f"{parser_name}_p{block.page_num}_e{order}" + markdown_value = block.text if block.block_type in {"title", "heading", "table", "list_item", "image"} else None + text_value = None if block.block_type == "table" else block.text + + if block.block_type == "table": + table_count += 1 + table_id = f"{parser_name}_t{table_count}" + tables.append( + TableObject( + table_id=table_id, + page_nums=[block.page_num], + markdown=block.text, + confidence=min(confidence, 0.84), + source_parser=parser_name, + provenance={"element_id": element_id, **block.metadata}, + ) + ) + + if block.block_type == "image": + figure_count += 1 + match = IMAGE_RE.search(block.text) + image_path = match.group("src") if match else None + caption = match.group("alt") if match and match.group("alt") else None + figures.append( + FigureObject( + figure_id=f"{parser_name}_f{figure_count}", + page_num=block.page_num, + image_path=image_path, + caption=caption, + confidence=min(confidence, 0.72), + source_parser=parser_name, + provenance={"element_id": element_id, **block.metadata}, + ) + ) + + elements.append( + Element( + element_id=element_id, + doc_id=doc_id, + page_num=block.page_num, + type=block.block_type, + text=text_value, + markdown=markdown_value, + reading_order=order, + confidence=confidence, + source_parser=parser_name, + provenance={"parser": parser_name, **block.metadata}, + ) + ) + + page_nums = sorted({block.page_num for block in blocks} or {1}) + return ParseCandidate( + parser_name=parser_name, + doc_id=doc_id, + source_path=source_path, + file_type=file_type, + pages=[{"page_num": page_num, "source_parser": parser_name} for page_num in page_nums], + elements=elements, + tables=tables, + figures=figures, + confidence=confidence if elements else 0.25, + provenance={"parser": parser_name, "normalizer": "markdown", **(provenance or {})}, + ) + + +def markdown_to_blocks(markdown: str) -> list[MarkdownBlock]: + page_num = 1 + blocks: list[MarkdownBlock] = [] + paragraph: list[str] = [] + table: list[str] = [] + + def flush_paragraph() -> None: + nonlocal paragraph + if paragraph: + text = " ".join(line.strip() for line in paragraph if line.strip()).strip() + if text: + blocks.append(_make_block(_classify_text_block(text), text, page_num, {})) + paragraph = [] + + def flush_table() -> None: + nonlocal table + if table: + normalized = normalize_markdown_table("\n".join(table)) + if normalized: + blocks.append(_make_block("table", normalized, page_num, {"table_line_count": len(table)})) + table = [] + + for raw_line in markdown.splitlines(): + line = raw_line.rstrip() + stripped = line.strip() + page_match = PAGE_COMMENT_RE.match(stripped) or PAGE_HEADING_RE.match(stripped) + if page_match: + flush_paragraph() + flush_table() + page_num = int(page_match.group(1)) + continue + + if not stripped: + flush_paragraph() + flush_table() + continue + + if _is_table_line(stripped): + flush_paragraph() + table.append(stripped) + continue + + flush_table() + if stripped.startswith("#"): + flush_paragraph() + heading_level = len(stripped) - len(stripped.lstrip("#")) + blocks.append( + _make_block( + "title" if heading_level == 1 else "heading", + stripped, + page_num, + {"heading_level": heading_level}, + ) + ) + elif IMAGE_RE.search(stripped): + flush_paragraph() + blocks.append(_make_block("image", stripped, page_num, {})) + elif stripped.startswith(("- ", "* ", "+ ")) or re.match(r"^\d+[.)]\s+", stripped): + flush_paragraph() + blocks.append(_make_block("list_item", stripped, page_num, {})) + else: + paragraph.append(stripped) + + flush_paragraph() + flush_table() + return blocks + + +def normalize_markdown_table(table_markdown: str) -> str: + rows = [_split_table_row(row) for row in table_markdown.splitlines() if row.strip()] + rows = [row for row in rows if row] + rows = [row for row in rows if not _is_separator_row(row)] + if not rows: + return "" + width = max(len(row) for row in rows) + padded = [row + [""] * (width - len(row)) for row in rows] + lines = [_markdown_row(padded[0]), _markdown_row(["---"] * width)] + lines.extend(_markdown_row(row) for row in padded[1:]) + return "\n".join(lines) + + +def _make_block(block_type: str, text: str, page_num: int, metadata: dict) -> MarkdownBlock: + return MarkdownBlock(block_type=block_type, text=text, page_num=page_num, metadata=metadata) + + +def _classify_text_block(text: str) -> str: + if CAPTION_RE.match(text): + return "caption" + return "paragraph" + + +def _is_markdown_table_line(line: str) -> bool: + return line.startswith("|") and line.endswith("|") and line.count("|") >= 2 + + +def _is_table_line(line: str) -> bool: + return _is_markdown_table_line(line) or _is_plain_table_line(line) + + +def _is_plain_table_line(line: str) -> bool: + if "|" in line or "\t" in line: + return "\t" in line + cells = PLAIN_TABLE_SPLIT_RE.split(line.strip()) + if len(cells) < 2: + return False + # Header rows often have no digits, so allow them only while a table block is forming. + return any(any(char.isdigit() for char in cell) for cell in cells) or len(cells) >= 3 + + +def _split_table_row(row: str) -> list[str]: + if _is_markdown_table_line(row): + return _split_markdown_row(row) + return [_clean_cell(cell) for cell in PLAIN_TABLE_SPLIT_RE.split(row.strip()) if cell.strip()] + + +def _split_markdown_row(row: str) -> list[str]: + return [_clean_cell(cell) for cell in row.strip().strip("|").split("|")] + + +def _is_separator_row(row: list[str]) -> bool: + return all(re.fullmatch(r":?-{3,}:?", cell.strip()) for cell in row if cell.strip()) + + +def _markdown_row(cells: list[str]) -> str: + return "| " + " | ".join(_escape_cell(cell) for cell in cells) + " |" + + +def _clean_cell(cell: str) -> str: + return " ".join(cell.strip().split()) + + +def _escape_cell(cell: str) -> str: + return cell.replace("|", "\\|") diff --git a/zsgdp/normalize/normalize_docling.py b/zsgdp/normalize/normalize_docling.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd06de2a2e6519e1fc086f88894f1214035999d --- /dev/null +++ b/zsgdp/normalize/normalize_docling.py @@ -0,0 +1,7 @@ +"""Docling normalization helpers.""" + +from __future__ import annotations + +from zsgdp.normalize.markdown import normalize_markdown_candidate + +__all__ = ["normalize_markdown_candidate"] diff --git a/zsgdp/normalize/normalize_marker.py b/zsgdp/normalize/normalize_marker.py new file mode 100644 index 0000000000000000000000000000000000000000..a301e1398263b07e3a6848d618b5f5e518b55dc6 --- /dev/null +++ b/zsgdp/normalize/normalize_marker.py @@ -0,0 +1,18 @@ +"""Normalize Marker Markdown into the canonical parse schema.""" + +from __future__ import annotations + +from zsgdp.normalize.markdown import normalize_markdown_candidate +from zsgdp.schema import DocumentProfile, ParseCandidate + + +def normalize_marker_markdown(*, markdown: str, profile: DocumentProfile, source_path: str) -> ParseCandidate: + return normalize_markdown_candidate( + markdown=markdown, + doc_id=profile.doc_id, + source_path=source_path, + file_type=profile.file_type, + parser_name="marker", + confidence=0.84, + provenance={"backend": "marker", "normalizer": "normalize_marker_markdown"}, + ) diff --git a/zsgdp/normalize/normalize_mineru.py b/zsgdp/normalize/normalize_mineru.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc65b9385740c5cb4b4f41b3b01c612ccd3b7c0 --- /dev/null +++ b/zsgdp/normalize/normalize_mineru.py @@ -0,0 +1,18 @@ +"""MinerU Markdown normalization.""" + +from __future__ import annotations + +from zsgdp.normalize.markdown import normalize_markdown_candidate +from zsgdp.schema import DocumentProfile, ParseCandidate + + +def normalize_mineru_markdown(*, markdown: str, profile: DocumentProfile, source_path: str) -> ParseCandidate: + return normalize_markdown_candidate( + markdown=markdown, + doc_id=profile.doc_id, + source_path=source_path, + file_type=profile.file_type, + parser_name="mineru", + confidence=0.82, + provenance={"backend": "mineru"}, + ) diff --git a/zsgdp/normalize/normalize_olmocr.py b/zsgdp/normalize/normalize_olmocr.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7667713f38928d532fa1b353d18ab32e8f1315 --- /dev/null +++ b/zsgdp/normalize/normalize_olmocr.py @@ -0,0 +1,18 @@ +"""olmOCR Markdown normalization.""" + +from __future__ import annotations + +from zsgdp.normalize.markdown import normalize_markdown_candidate +from zsgdp.schema import DocumentProfile, ParseCandidate + + +def normalize_olmocr_markdown(*, markdown: str, profile: DocumentProfile, source_path: str) -> ParseCandidate: + return normalize_markdown_candidate( + markdown=markdown, + doc_id=profile.doc_id, + source_path=source_path, + file_type=profile.file_type, + parser_name="olmocr", + confidence=0.80, + provenance={"backend": "olmocr"}, + ) diff --git a/zsgdp/normalize/normalize_paddleocr.py b/zsgdp/normalize/normalize_paddleocr.py new file mode 100644 index 0000000000000000000000000000000000000000..bef855df0ab0e33acc0b75c9a739f65459bf854a --- /dev/null +++ b/zsgdp/normalize/normalize_paddleocr.py @@ -0,0 +1,18 @@ +"""PaddleOCR Markdown normalization.""" + +from __future__ import annotations + +from zsgdp.normalize.markdown import normalize_markdown_candidate +from zsgdp.schema import DocumentProfile, ParseCandidate + + +def normalize_paddleocr_markdown(*, markdown: str, profile: DocumentProfile, source_path: str) -> ParseCandidate: + return normalize_markdown_candidate( + markdown=markdown, + doc_id=profile.doc_id, + source_path=source_path, + file_type=profile.file_type, + parser_name="paddleocr", + confidence=0.78, + provenance={"backend": "paddleocr"}, + ) diff --git a/zsgdp/normalize/normalize_unstructured.py b/zsgdp/normalize/normalize_unstructured.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac5fc600df51306e7851b0540f0015df902f1bb --- /dev/null +++ b/zsgdp/normalize/normalize_unstructured.py @@ -0,0 +1,80 @@ +"""Normalize Unstructured partition elements into the canonical parse schema.""" + +from __future__ import annotations + +from typing import Any, Iterable + +from zsgdp.normalize.markdown import normalize_markdown_candidate +from zsgdp.schema import DocumentProfile, ParseCandidate + + +def normalize_unstructured_parts( + *, + parts: Iterable[Any], + profile: DocumentProfile, + source_path: str, +) -> ParseCandidate: + markdown = _parts_to_markdown(parts) + candidate = normalize_markdown_candidate( + markdown=markdown, + doc_id=profile.doc_id, + source_path=source_path, + file_type=profile.file_type, + parser_name="unstructured", + confidence=0.78, + provenance={"backend": "unstructured", "normalizer": "normalize_unstructured_parts"}, + ) + return candidate + + +def _parts_to_markdown(parts: Iterable[Any]) -> str: + lines: list[str] = [] + current_page: int | None = None + for part in parts: + text = str(part).strip() + if not text: + continue + page_num = _part_page_num(part) + if page_num and page_num != current_page: + current_page = page_num + if lines: + lines.append("") + lines.append(f"") + lines.append("") + lines.append(_part_to_markdown(part, text)) + lines.append("") + return "\n".join(lines).strip() + ("\n" if lines else "") + + +def _part_to_markdown(part: Any, text: str) -> str: + category = _part_category(part).lower() + if category in {"title", "header"} and not text.startswith("#"): + return f"# {text}" + if category in {"table"}: + html = _part_metadata_value(part, "text_as_html") + return str(html).strip() if html else text + return text + + +def _part_category(part: Any) -> str: + category = getattr(part, "category", None) + if category: + return str(category) + return part.__class__.__name__ + + +def _part_page_num(part: Any) -> int | None: + value = _part_metadata_value(part, "page_number") + try: + return int(value) if value is not None else None + except (TypeError, ValueError): + return None + + +def _part_metadata_value(part: Any, key: str) -> Any: + metadata = getattr(part, "metadata", None) + if metadata is None: + return None + if isinstance(metadata, dict): + return metadata.get(key) + return getattr(metadata, key, None) diff --git a/zsgdp/parsers/__init__.py b/zsgdp/parsers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7626a4147a5b1dbf7d650485f05b7a772dbc31b1 --- /dev/null +++ b/zsgdp/parsers/__init__.py @@ -0,0 +1,3 @@ +from zsgdp.parsers.registry import get_parser, parser_names + +__all__ = ["get_parser", "parser_names"] diff --git a/zsgdp/parsers/base.py b/zsgdp/parsers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3101caf6846a050ba64bc2b5eb8130445b6348e9 --- /dev/null +++ b/zsgdp/parsers/base.py @@ -0,0 +1,37 @@ +"""Parser adapter base classes.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from zsgdp.schema import DocumentProfile, ParseCandidate + + +class ParserUnavailableError(RuntimeError): + """Raised when a parser backend is not installed or cannot run.""" + + +class ParseError(RuntimeError): + """Raised when a parser backend fails after being selected.""" + + +class BaseParser: + name = "base" + supported_file_types: set[str] = set() + + def available(self) -> bool: + return True + + def supports(self, file_type: str) -> bool: + return not self.supported_file_types or file_type in self.supported_file_types + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + raise NotImplementedError diff --git a/zsgdp/parsers/docling_parser.py b/zsgdp/parsers/docling_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f626c53a5c4abe6aff7c56cd16ac855d7ea8445c --- /dev/null +++ b/zsgdp/parsers/docling_parser.py @@ -0,0 +1,150 @@ +"""Docling parser adapter.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from zsgdp.normalize.markdown import normalize_markdown_candidate +from zsgdp.parsers.base import BaseParser, ParseError, ParserUnavailableError +from zsgdp.schema import DocumentProfile, ParseCandidate + + +class DoclingParser(BaseParser): + name = "docling" + supported_file_types = {"pdf", "docx", "pptx", "xlsx", "html", "image"} + + def available(self) -> bool: + try: + import docling.document_converter # type: ignore # noqa: F401 + except ImportError: + return False + return True + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + try: + from docling.document_converter import DocumentConverter # type: ignore + except ImportError as exc: + raise ParserUnavailableError("Docling is not installed.") from exc + + try: + converter = _build_converter(DocumentConverter, profile.file_type, config.get("parsers", {}).get("docling", {})) + result = converter.convert(str(path)) + document = getattr(result, "document", result) + markdown = _export_markdown(document) + provenance = _docling_provenance(result, document) + except Exception as exc: + raise ParseError(f"Docling failed to parse {path}: {exc}") from exc + + candidate = normalize_docling_markdown( + markdown=markdown, + profile=profile, + source_path=str(path), + confidence=0.88, + provenance=provenance, + ) + candidate.pages = _page_records(candidate, profile) + return candidate + + +def _build_converter(document_converter_cls: Any, file_type: str, parser_config: dict[str, Any]) -> Any: + if file_type != "pdf": + return document_converter_cls() + + try: + from docling.datamodel.base_models import InputFormat # type: ignore + from docling.datamodel.pipeline_options import PdfPipelineOptions # type: ignore + from docling.document_converter import PdfFormatOption # type: ignore + except ImportError: + return document_converter_cls() + + pipeline_options = PdfPipelineOptions( + do_ocr=bool(parser_config.get("do_ocr", False)), + do_table_structure=bool(parser_config.get("do_table_structure", False)), + force_backend_text=bool(parser_config.get("force_backend_text", True)), + generate_page_images=bool(parser_config.get("generate_page_images", False)), + generate_picture_images=bool(parser_config.get("generate_picture_images", False)), + generate_table_images=bool(parser_config.get("generate_table_images", False)), + do_picture_description=bool(parser_config.get("do_picture_description", False)), + do_picture_classification=bool(parser_config.get("do_picture_classification", False)), + do_formula_enrichment=bool(parser_config.get("do_formula_enrichment", False)), + do_code_enrichment=bool(parser_config.get("do_code_enrichment", False)), + ) + return document_converter_cls( + format_options={ + InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options), + } + ) + + +def normalize_docling_markdown( + *, + markdown: str, + profile: DocumentProfile, + source_path: str, + confidence: float = 0.88, + provenance: dict | None = None, +) -> ParseCandidate: + candidate = normalize_markdown_candidate( + markdown=markdown, + doc_id=profile.doc_id, + source_path=source_path, + file_type=profile.file_type, + parser_name="docling", + confidence=confidence, + provenance={"docling_export": "markdown", **(provenance or {})}, + ) + known_pages = {page.page_num for page in profile.pages} + for element in candidate.elements: + if element.page_num not in known_pages and known_pages: + element.provenance["page_warning"] = "Page marker not found in profile." + return candidate + + +def _export_markdown(document: Any) -> str: + for method_name in ("export_to_markdown", "export_to_md", "to_markdown"): + method = getattr(document, method_name, None) + if callable(method): + markdown = method() + if isinstance(markdown, str): + return markdown + markdown_attr = getattr(document, "markdown", None) + if isinstance(markdown_attr, str): + return markdown_attr + raise ParseError("Docling document did not expose a Markdown export method.") + + +def _docling_provenance(result: Any, document: Any) -> dict[str, Any]: + provenance: dict[str, Any] = { + "result_type": type(result).__name__, + "document_type": type(document).__name__, + } + status = getattr(result, "status", None) + if status is not None: + provenance["status"] = str(status) + timings = getattr(result, "timings", None) + if timings is not None: + provenance["timings"] = str(timings) + return provenance + + +def _page_records(candidate: ParseCandidate, profile: DocumentProfile) -> list[dict[str, Any]]: + candidate_pages = {page.get("page_num", 1) for page in candidate.pages} + profile_pages = {page.page_num for page in profile.pages} + page_nums = sorted(candidate_pages | profile_pages | {1}) + return [ + { + "page_num": page_num, + "source_parser": "docling", + "profile_available": page_num in profile_pages, + "markdown_page_marker": page_num in candidate_pages, + } + for page_num in page_nums + ] diff --git a/zsgdp/parsers/external.py b/zsgdp/parsers/external.py new file mode 100644 index 0000000000000000000000000000000000000000..757576fb8a6f674de2a7fc10e9c7ed3a13608f8e --- /dev/null +++ b/zsgdp/parsers/external.py @@ -0,0 +1,264 @@ +"""Adapters for external parsers that are optional in the MVP.""" + +from __future__ import annotations + +from importlib.util import find_spec +from pathlib import Path +import shutil +import subprocess +import tempfile +from typing import Any + +from zsgdp.normalize.normalize_marker import normalize_marker_markdown +from zsgdp.normalize.normalize_mineru import normalize_mineru_markdown +from zsgdp.normalize.normalize_olmocr import normalize_olmocr_markdown +from zsgdp.normalize.normalize_paddleocr import normalize_paddleocr_markdown +from zsgdp.normalize.normalize_unstructured import normalize_unstructured_parts +from zsgdp.parsers.base import BaseParser, ParseError, ParserUnavailableError +from zsgdp.schema import DocumentProfile, ParseCandidate + +DEFAULT_COMMANDS = { + "marker": ("marker_single", "marker"), + "mineru": ("mineru", "magic-pdf"), + "olmocr": ("olmocr", "olmocr-pipeline"), + "paddleocr": ("paddleocr", "paddleocr-structure"), +} + + +class MarkerParser(BaseParser): + name = "marker" + supported_file_types = {"pdf", "docx", "pptx", "xlsx", "html", "image", "epub"} + + def available(self) -> bool: + return _external_parser_available("marker") + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + if not self.available(): + raise ParserUnavailableError("Marker is not installed.") + markdown = run_marker_to_markdown(path, config) + candidate = normalize_marker_markdown(markdown=markdown, profile=profile, source_path=str(path)) + _annotate_external_candidate(candidate, profile, pages) + return candidate + + +class MinerUParser(BaseParser): + name = "mineru" + supported_file_types = {"pdf", "docx", "pptx", "xlsx", "image"} + + def available(self) -> bool: + return _external_parser_available("mineru") + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + if not self.available(): + raise ParserUnavailableError("MinerU is not installed.") + markdown = run_external_parser_to_markdown("mineru", path, config) + candidate = normalize_mineru_markdown(markdown=markdown, profile=profile, source_path=str(path)) + _annotate_external_candidate(candidate, profile, pages) + return candidate + + +class OlmOCRParser(BaseParser): + name = "olmocr" + supported_file_types = {"pdf", "image"} + + def available(self) -> bool: + return _external_parser_available("olmocr") + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + if not self.available(): + raise ParserUnavailableError("olmOCR is not installed.") + markdown = run_external_parser_to_markdown("olmocr", path, config) + candidate = normalize_olmocr_markdown(markdown=markdown, profile=profile, source_path=str(path)) + _annotate_external_candidate(candidate, profile, pages) + return candidate + + +class PaddleOCRParser(BaseParser): + name = "paddleocr" + supported_file_types = {"pdf", "image"} + + def available(self) -> bool: + return _external_parser_available("paddleocr") + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + if not self.available(): + raise ParserUnavailableError("PaddleOCR is not installed.") + markdown = run_external_parser_to_markdown("paddleocr", path, config) + candidate = normalize_paddleocr_markdown(markdown=markdown, profile=profile, source_path=str(path)) + _annotate_external_candidate(candidate, profile, pages) + return candidate + + +class UnstructuredParser(BaseParser): + name = "unstructured" + supported_file_types = {"pdf", "docx", "pptx", "xlsx", "html", "image", "epub", "text", "markdown"} + + def available(self) -> bool: + return find_spec("unstructured") is not None + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + if not self.available(): + raise ParserUnavailableError("Unstructured is not installed.") + + try: + from unstructured.partition.auto import partition # type: ignore + + parts = partition(filename=str(path)) + except Exception as exc: + raise ParseError(f"Unstructured failed to parse {path}: {exc}") from exc + + candidate = normalize_unstructured_parts(parts=parts, profile=profile, source_path=str(path)) + candidate.pages = [{"page_num": page.page_num, "source_parser": self.name} for page in profile.pages] + return candidate + + +def run_marker_to_markdown(path: str | Path, config: dict[str, Any]) -> str: + return run_external_parser_to_markdown( + "marker", + path, + config, + default_output_args=["--output_dir", "{output_dir}", "--output_format", "markdown"], + ) + + +def run_external_parser_to_markdown( + parser_name: str, + path: str | Path, + config: dict[str, Any], + *, + default_output_args: list[str] | None = None, +) -> str: + parser_config = config.get("parsers", {}).get(parser_name, {}) + command = parser_config.get("command") or _parser_command(parser_name) + if not command: + raise ParserUnavailableError( + f"{parser_name} CLI was not found. Install {parser_name} or set parsers.{parser_name}.command." + ) + + timeout = float(parser_config.get("timeout_seconds", 300)) + output_args = _coerce_args(parser_config.get("output_args", default_output_args or ["--output_dir", "{output_dir}"]), parser_name, "output_args") + extra_args = _coerce_args(parser_config.get("extra_args", []), parser_name, "extra_args") + + with tempfile.TemporaryDirectory(prefix=f"zsgdp_{parser_name}_") as tmp: + output_dir = Path(tmp) / f"{parser_name}_out" + output_dir.mkdir(parents=True, exist_ok=True) + cmd = [ + str(command), + str(path), + *[_format_external_arg(arg, output_dir=output_dir, input_path=Path(path)) for arg in output_args], + *[str(arg) for arg in extra_args], + ] + try: + completed = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired as exc: + raise ParseError(f"{parser_name} timed out after {timeout:.1f}s while parsing {path}.") from exc + except OSError as exc: + raise ParserUnavailableError(f"{parser_name} failed to start: {exc}") from exc + + if completed.returncode != 0: + stderr = completed.stderr.strip() or completed.stdout.strip() + raise ParseError(f"{parser_name} failed with exit code {completed.returncode}: {stderr}") + return _read_external_markdown(output_dir, parser_name=parser_name, stdout=completed.stdout) + + +def _marker_command() -> str | None: + return _parser_command("marker") + + +def _parser_command(parser_name: str) -> str | None: + for command in DEFAULT_COMMANDS.get(parser_name, (parser_name,)): + resolved = shutil.which(command) + if resolved: + return resolved + return None + + +def _external_parser_available(parser_name: str) -> bool: + return find_spec(parser_name.replace("-", "_")) is not None or _parser_command(parser_name) is not None + + +def _read_marker_markdown(output_dir: Path) -> str: + return _read_external_markdown(output_dir, parser_name="marker") + + +def _read_external_markdown(output_dir: Path, *, parser_name: str, stdout: str = "") -> str: + markdown_files = sorted(output_dir.rglob("*.md")) + if markdown_files: + preferred = sorted(markdown_files, key=lambda item: (item.name != "markdown.md", len(item.parts), item.name))[0] + markdown = preferred.read_text(encoding="utf-8", errors="replace").strip() + if markdown: + return markdown + raise ParseError(f"{parser_name} produced an empty Markdown file: {preferred}") + + stdout_text = stdout.strip() + if stdout_text: + return stdout_text + raise ParseError(f"{parser_name} did not produce Markdown in {output_dir} or stdout.") + + +def _coerce_args(value: Any, parser_name: str, key: str) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return value.split() + if isinstance(value, list): + return [str(item) for item in value] + raise ParseError(f"parsers.{parser_name}.{key} must be a string or list of strings.") + + +def _format_external_arg(arg: str, *, output_dir: Path, input_path: Path) -> str: + return arg.format(output_dir=str(output_dir), input_path=str(input_path)) + + +def _annotate_external_candidate( + candidate: ParseCandidate, + profile: DocumentProfile, + pages: list[int] | None, +) -> None: + candidate.provenance.update( + { + "requested_pages": pages or [page.page_num for page in profile.pages], + "profile_page_count": profile.page_count, + } + ) diff --git a/zsgdp/parsers/pymupdf_parser.py b/zsgdp/parsers/pymupdf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a9cc7217b69cbb4f3c58f3407aa0a94dd9f940 --- /dev/null +++ b/zsgdp/parsers/pymupdf_parser.py @@ -0,0 +1,437 @@ +"""PyMuPDF parser adapter with layout, rendering, and crop artifacts.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import re +from statistics import median +from typing import Any + +from zsgdp.parsers.base import BaseParser, ParserUnavailableError +from zsgdp.profiling.page_render import crop_pdf_region, render_pdf_pages +from zsgdp.schema import DocumentProfile, Element, FigureObject, ParseCandidate, TableObject + + +TABLE_SPLIT_RE = re.compile(r"\t+|\s{2,}") +LIST_RE = re.compile(r"^(\d+[.)]|[A-Za-z][.)]|[-*+])\s+") +CAPTION_RE = re.compile(r"^(figure|fig\.|table|chart|exhibit)\s+\d*", re.IGNORECASE) +FORMULA_RE = re.compile(r"(\\(?:frac|sum|int|sqrt|alpha|beta|gamma|theta|sigma)|[=+\-*/^_<>]{2,})") + + +@dataclass(slots=True) +class TextBlock: + page_num: int + text: str + bbox: tuple[float, float, float, float] + block_no: int + max_font_size: float = 0.0 + avg_font_size: float = 0.0 + bold: bool = False + line_count: int = 0 + + +class PyMuPDFParser(BaseParser): + name = "pymupdf" + supported_file_types = {"pdf"} + + def available(self) -> bool: + try: + import fitz # type: ignore # noqa: F401 + except ImportError: + return False + return True + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + try: + import fitz # type: ignore + except ImportError as exc: + raise ParserUnavailableError("PyMuPDF is not installed. Install with `python -m pip install -e '.[pdf]'`.") from exc + + selected_pages = set(pages or [page.page_num for page in profile.pages]) + elements: list[Element] = [] + tables: list[TableObject] = [] + figures: list[FigureObject] = [] + page_records: list[dict[str, Any]] = [] + provenance: dict[str, Any] = {"parser": self.name, "mode": "dict_blocks"} + pdf_config = config.get("pdf", {}) + dpi = int(pdf_config.get("render_dpi", 150)) + asset_root = _asset_root(config) + rendered_pages = {} + + if asset_root and pdf_config.get("render_pages", True): + try: + rendered_pages = { + rendered.page_num: rendered.to_dict() + for rendered in render_pdf_pages(path, asset_root / "pages", pages=selected_pages, dpi=dpi) + } + except Exception as exc: # Rendering should not prevent text parsing. + provenance["render_warning"] = str(exc) + + with fitz.open(path) as pdf: + for index, page in enumerate(pdf, start=1): + if index not in selected_pages: + continue + + page_record = { + "page_num": index, + "source_parser": self.name, + "width": float(page.rect.width), + "height": float(page.rect.height), + "rotation": int(page.rotation or 0), + } + if index in rendered_pages: + page_record["rendered_page"] = rendered_pages[index] + page_records.append(page_record) + + text_blocks = _sort_blocks_reading_order( + _extract_text_blocks(page.get_text("dict"), index), + float(page.rect.width), + ) + page_font_median = median([block.avg_font_size for block in text_blocks if block.avg_font_size] or [0.0]) + + table_count = 0 + for order, block in enumerate(text_blocks, start=1): + element_type = _guess_element_type(block, order, page_font_median, float(page.rect.height)) + markdown = None + text = block.text + confidence = _confidence_for_type(element_type) + provenance_data = { + "block_no": block.block_no, + "max_font_size": block.max_font_size, + "avg_font_size": block.avg_font_size, + "bold": block.bold, + "line_count": block.line_count, + } + + if element_type == "table": + table_count += 1 + table_id = f"{self.name}_p{index}_t{table_count}" + markdown = _table_text_to_markdown(block.text) + crop_path = None + if asset_root and pdf_config.get("crop_tables", True) and _valid_bbox(block.bbox): + crop_path = asset_root / "tables" / f"p{index:04d}_t{table_count:03d}.png" + try: + crop = crop_pdf_region(path, index, block.bbox, crop_path, dpi=dpi) + provenance_data["table_crop"] = crop.to_dict() + except Exception as exc: + provenance_data["table_crop_warning"] = str(exc) + tables.append( + TableObject( + table_id=table_id, + page_nums=[index], + bbox=[block.bbox], + markdown=markdown, + confidence=0.72, + source_parser=self.name, + provenance={ + "element_id": f"{self.name}_p{index}_e{order}", + "crop_path": str(crop_path) if crop_path else None, + }, + ) + ) + text = None + + elements.append( + Element( + element_id=f"{self.name}_p{index}_e{order}", + doc_id=profile.doc_id, + page_num=index, + type=element_type, + text=text, + markdown=markdown, + bbox=block.bbox, + reading_order=order, + confidence=confidence, + source_parser=self.name, + provenance=provenance_data, + ) + ) + + fig_index = 0 + for image in _image_infos(page): + fig_index += 1 + bbox = _coerce_bbox(image.get("bbox")) + image_path = None + fig_provenance = { + "source": "image", + "xref": image.get("xref"), + "width": image.get("width"), + "height": image.get("height"), + } + if asset_root and pdf_config.get("crop_figures", True) and bbox and _valid_bbox(bbox): + crop_path = asset_root / "figures" / f"p{index:04d}_f{fig_index:03d}.png" + try: + crop = crop_pdf_region(path, index, bbox, crop_path, dpi=dpi) + image_path = crop.image_path + fig_provenance["figure_crop"] = crop.to_dict() + except Exception as exc: + fig_provenance["figure_crop_warning"] = str(exc) + figures.append( + FigureObject( + figure_id=f"{self.name}_p{index}_f{fig_index}", + page_num=index, + bbox=bbox, + image_path=image_path, + confidence=0.58 if bbox else 0.40, + source_parser=self.name, + provenance=fig_provenance, + ) + ) + + for bbox in _drawing_regions(page): + fig_index += 1 + image_path = None + fig_provenance = {"source": "drawing"} + if asset_root and pdf_config.get("crop_figures", True) and _valid_bbox(bbox): + crop_path = asset_root / "figures" / f"p{index:04d}_f{fig_index:03d}.png" + try: + crop = crop_pdf_region(path, index, bbox, crop_path, dpi=dpi) + image_path = crop.image_path + fig_provenance["figure_crop"] = crop.to_dict() + except Exception as exc: + fig_provenance["figure_crop_warning"] = str(exc) + figures.append( + FigureObject( + figure_id=f"{self.name}_p{index}_f{fig_index}", + page_num=index, + bbox=bbox, + image_path=image_path, + confidence=0.50, + source_parser=self.name, + provenance=fig_provenance, + ) + ) + + return ParseCandidate( + parser_name=self.name, + doc_id=profile.doc_id, + source_path=str(path), + file_type=profile.file_type, + pages=page_records, + elements=elements, + tables=tables, + figures=figures, + confidence=0.86 if elements else 0.35, + provenance=provenance, + ) + + +def _extract_text_blocks(page_dict: dict[str, Any], page_num: int) -> list[TextBlock]: + blocks: list[TextBlock] = [] + for raw_block in page_dict.get("blocks", []): + if raw_block.get("type", 0) != 0: + continue + lines: list[str] = [] + font_sizes: list[float] = [] + bold = False + for line in raw_block.get("lines", []): + spans = line.get("spans", []) + line_text = "".join(span.get("text", "") for span in spans).strip() + if line_text: + lines.append(line_text) + for span in spans: + size = float(span.get("size", 0.0) or 0.0) + if size: + font_sizes.append(size) + font_name = str(span.get("font", "")).lower() + flags = int(span.get("flags", 0) or 0) + bold = bold or "bold" in font_name or bool(flags & 16) + text = "\n".join(lines).strip() + bbox = _coerce_bbox(raw_block.get("bbox")) + if text and bbox: + blocks.append( + TextBlock( + page_num=page_num, + text=text, + bbox=bbox, + block_no=int(raw_block.get("number", len(blocks))), + max_font_size=max(font_sizes) if font_sizes else 0.0, + avg_font_size=(sum(font_sizes) / len(font_sizes)) if font_sizes else 0.0, + bold=bold, + line_count=len(lines), + ) + ) + return blocks + + +def _sort_blocks_reading_order(blocks: list[TextBlock], page_width: float) -> list[TextBlock]: + if not blocks: + return [] + two_columns = _has_two_columns(blocks, page_width) + if not two_columns: + return sorted(blocks, key=lambda block: (round(block.bbox[1], 1), round(block.bbox[0], 1))) + return sorted(blocks, key=lambda block: (_column_bucket(block, page_width), round(block.bbox[1], 1), round(block.bbox[0], 1))) + + +def _has_two_columns(blocks: list[TextBlock], page_width: float) -> bool: + if len(blocks) < 6: + return False + centers = [(block.bbox[0] + block.bbox[2]) / 2 for block in blocks] + left = [center for center in centers if center < page_width * 0.45] + right = [center for center in centers if center > page_width * 0.55] + return len(left) >= 3 and len(right) >= 3 + + +def _column_bucket(block: TextBlock, page_width: float) -> int: + center = (block.bbox[0] + block.bbox[2]) / 2 + return 0 if center < page_width * 0.5 else 1 + + +def _guess_element_type(block: TextBlock, order: int, page_font_median: float, page_height: float) -> str: + compact = " ".join(block.text.split()) + if _is_table_text(block.text): + return "table" + if _looks_like_formula(block.text): + return "formula" + if CAPTION_RE.match(compact): + return "caption" + if order == 1 and len(compact) <= 140: + return "title" + if block.bbox[1] < page_height * 0.07 and len(compact) <= 100: + return "header" + if block.bbox[3] > page_height * 0.93 and len(compact) <= 100: + return "footer" + if block.max_font_size and page_font_median and block.max_font_size >= page_font_median * 1.25 and len(compact) <= 160: + return "heading" + if block.bold and len(compact) <= 120 and not compact.endswith("."): + return "heading" + if compact.endswith(":") and len(compact) <= 120: + return "heading" + if LIST_RE.match(compact): + return "list_item" + return "paragraph" + + +def _confidence_for_type(element_type: str) -> float: + return { + "title": 0.88, + "heading": 0.84, + "paragraph": 0.86, + "list_item": 0.82, + "table": 0.72, + "formula": 0.70, + "caption": 0.78, + "header": 0.75, + "footer": 0.75, + }.get(element_type, 0.80) + + +def _is_table_text(text: str) -> bool: + lines = [line.strip() for line in text.splitlines() if line.strip()] + if len(lines) < 2: + return False + if sum(1 for line in lines if line.startswith("|") and line.endswith("|")) >= 2: + return True + split_rows = [_split_table_row(line) for line in lines] + multi_cell_rows = [row for row in split_rows if len(row) >= 2] + numeric_rows = [row for row in multi_cell_rows if any(any(char.isdigit() for char in cell) for cell in row)] + return len(multi_cell_rows) >= 2 and (len(numeric_rows) >= 1 or len(multi_cell_rows) >= 3) + + +def _table_text_to_markdown(text: str) -> str: + lines = [line.strip() for line in text.splitlines() if line.strip()] + if not lines: + return "" + rows = [_split_table_row(line) for line in lines] + rows = [row for row in rows if row] + if not rows: + return "" + width = max(len(row) for row in rows) + padded = [row + [""] * (width - len(row)) for row in rows] + markdown_lines = [_markdown_row(padded[0]), _markdown_row(["---"] * width)] + markdown_lines.extend(_markdown_row(row) for row in padded[1:]) + return "\n".join(markdown_lines) + + +def _split_table_row(line: str) -> list[str]: + stripped = line.strip() + if stripped.startswith("|") and stripped.endswith("|"): + return [_clean_cell(cell) for cell in stripped.strip("|").split("|")] + return [_clean_cell(cell) for cell in TABLE_SPLIT_RE.split(stripped) if cell.strip()] + + +def _markdown_row(cells: list[str]) -> str: + return "| " + " | ".join(_escape_markdown_cell(cell) for cell in cells) + " |" + + +def _clean_cell(cell: str) -> str: + return " ".join(cell.strip().split()) + + +def _escape_markdown_cell(cell: str) -> str: + return cell.replace("|", "\\|") + + +def _looks_like_formula(text: str) -> bool: + compact = "".join(text.split()) + if len(compact) < 3: + return False + symbol_count = sum(1 for char in compact if char in "=+-*/^_<>") + return bool(FORMULA_RE.search(text)) or symbol_count / max(len(compact), 1) > 0.18 + + +def _image_infos(page: Any) -> list[dict[str, Any]]: + try: + return list(page.get_image_info(xrefs=True)) + except Exception: + return [{"xref": image[0]} for image in page.get_images(full=True)] + + +def _drawing_regions(page: Any) -> list[tuple[float, float, float, float]]: + regions: list[tuple[float, float, float, float]] = [] + try: + drawings = page.get_drawings() + except Exception: + return regions + for drawing in drawings: + bbox = _coerce_bbox(drawing.get("rect")) + if bbox and _valid_bbox(bbox) and _bbox_area(bbox) >= 400: + regions.append(bbox) + return _dedupe_bboxes(regions) + + +def _coerce_bbox(raw_bbox: Any) -> tuple[float, float, float, float] | None: + if not raw_bbox or len(raw_bbox) != 4: + return None + return (float(raw_bbox[0]), float(raw_bbox[1]), float(raw_bbox[2]), float(raw_bbox[3])) + + +def _valid_bbox(bbox: tuple[float, float, float, float]) -> bool: + return bbox[2] > bbox[0] and bbox[3] > bbox[1] and (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) > 4 + + +def _bbox_area(bbox: tuple[float, float, float, float]) -> float: + return max(bbox[2] - bbox[0], 0.0) * max(bbox[3] - bbox[1], 0.0) + + +def _dedupe_bboxes(bboxes: list[tuple[float, float, float, float]]) -> list[tuple[float, float, float, float]]: + deduped: list[tuple[float, float, float, float]] = [] + for bbox in bboxes: + if not any(_bbox_overlap_ratio(bbox, existing) > 0.9 for existing in deduped): + deduped.append(bbox) + return deduped + + +def _bbox_overlap_ratio(a: tuple[float, float, float, float], b: tuple[float, float, float, float]) -> float: + ix0, iy0 = max(a[0], b[0]), max(a[1], b[1]) + ix1, iy1 = min(a[2], b[2]), min(a[3], b[3]) + if ix1 <= ix0 or iy1 <= iy0: + return 0.0 + intersection = (ix1 - ix0) * (iy1 - iy0) + return intersection / max(min(_bbox_area(a), _bbox_area(b)), 1.0) + + +def _asset_root(config: dict[str, Any]) -> Path | None: + output_dir = config.get("runtime", {}).get("output_dir") + if not output_dir: + return None + asset_dir = config.get("pdf", {}).get("asset_dir", "assets") + return Path(output_dir) / str(asset_dir) diff --git a/zsgdp/parsers/registry.py b/zsgdp/parsers/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..10e58cbefe0eef35feb77f93dc945a4b0ed00922 --- /dev/null +++ b/zsgdp/parsers/registry.py @@ -0,0 +1,39 @@ +"""Parser registry.""" + +from __future__ import annotations + +from zsgdp.parsers.base import BaseParser +from zsgdp.parsers.docling_parser import DoclingParser +from zsgdp.parsers.external import ( + MarkerParser, + MinerUParser, + OlmOCRParser, + PaddleOCRParser, + UnstructuredParser, +) +from zsgdp.parsers.pymupdf_parser import PyMuPDFParser +from zsgdp.parsers.text_parser import TextParser + + +_PARSERS: dict[str, type[BaseParser]] = { + "text": TextParser, + "pymupdf": PyMuPDFParser, + "docling": DoclingParser, + "marker": MarkerParser, + "mineru": MinerUParser, + "olmocr": OlmOCRParser, + "paddleocr": PaddleOCRParser, + "unstructured": UnstructuredParser, +} + + +def parser_names() -> list[str]: + return sorted(_PARSERS) + + +def get_parser(name: str) -> BaseParser: + try: + parser_class = _PARSERS[name] + except KeyError as exc: + raise KeyError(f"Unknown parser: {name}") from exc + return parser_class() diff --git a/zsgdp/parsers/text_parser.py b/zsgdp/parsers/text_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..4daaf1f07698330faa48b1a1bba9ecb6a1349171 --- /dev/null +++ b/zsgdp/parsers/text_parser.py @@ -0,0 +1,153 @@ +"""Native text, Markdown, and simple HTML parser.""" + +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any + +from zsgdp.parsers.base import BaseParser +from zsgdp.schema import DocumentProfile, Element, ParseCandidate, TableObject + +HTML_TAG_RE = re.compile(r"<[^>]+>") + + +class TextParser(BaseParser): + name = "text" + supported_file_types = {"text", "markdown", "html"} + + def parse( + self, + path: str | Path, + profile: DocumentProfile, + config: dict[str, Any], + *, + pages: list[int] | None = None, + ) -> ParseCandidate: + path_obj = Path(path) + text = path_obj.read_text(encoding="utf-8", errors="replace") + if profile.file_type == "html": + text = HTML_TAG_RE.sub(" ", text) + return build_text_candidate( + text=text, + doc_id=profile.doc_id, + source_path=str(path_obj), + file_type=profile.file_type, + parser_name=self.name, + confidence=0.94, + ) + + +def build_text_candidate( + *, + text: str, + doc_id: str, + source_path: str, + file_type: str, + parser_name: str, + confidence: float, +) -> ParseCandidate: + blocks = _logical_blocks(text) + elements: list[Element] = [] + tables: list[TableObject] = [] + reading_order = 0 + table_index = 0 + + for block_type, block_text in blocks: + reading_order += 1 + element_id = f"{parser_name}_p1_e{reading_order}" + if block_type == "table": + table_index += 1 + table_id = f"{parser_name}_t{table_index}" + tables.append( + TableObject( + table_id=table_id, + page_nums=[1], + markdown=block_text, + confidence=0.82, + source_parser=parser_name, + provenance={"element_id": element_id}, + ) + ) + elements.append( + Element( + element_id=element_id, + doc_id=doc_id, + page_num=1, + type=block_type, + text=block_text if block_type != "table" else None, + markdown=block_text if block_type in {"heading", "title", "table", "list_item"} else None, + reading_order=reading_order, + confidence=confidence, + source_parser=parser_name, + provenance={"parser": parser_name}, + ) + ) + + return ParseCandidate( + parser_name=parser_name, + doc_id=doc_id, + source_path=source_path, + file_type=file_type, + pages=[{"page_num": 1, "source_parser": parser_name}], + elements=elements, + tables=tables, + confidence=confidence, + provenance={"parser": parser_name, "mode": "text_blocks"}, + ) + + +def _logical_blocks(text: str) -> list[tuple[str, str]]: + lines = text.splitlines() + blocks: list[tuple[str, str]] = [] + paragraph: list[str] = [] + table: list[str] = [] + + def flush_paragraph() -> None: + nonlocal paragraph + if paragraph: + joined = " ".join(line.strip() for line in paragraph if line.strip()) + if joined: + blocks.append(("paragraph", joined)) + paragraph = [] + + def flush_table() -> None: + nonlocal table + if table: + blocks.append(("table", "\n".join(table))) + table = [] + + for line in lines: + stripped = line.strip() + if not stripped: + flush_paragraph() + flush_table() + continue + + if _is_table_line(stripped): + flush_paragraph() + table.append(stripped) + continue + + flush_table() + if stripped.startswith("#"): + flush_paragraph() + heading_level = len(stripped) - len(stripped.lstrip("#")) + blocks.append(("title" if heading_level == 1 else "heading", stripped)) + elif stripped.startswith(("- ", "* ", "+ ")) or re.match(r"^\d+[.)]\s+", stripped): + flush_paragraph() + blocks.append(("list_item", stripped)) + else: + paragraph.append(stripped) + + flush_paragraph() + flush_table() + return blocks + + +def _is_table_line(line: str) -> bool: + if line.startswith("|") and line.endswith("|"): + return True + if "\t" in line: + return True + return bool(re.search(r"\S\s{2,}\S", line)) and sum(char.isdigit() for char in line) >= 1 diff --git a/zsgdp/pipeline.py b/zsgdp/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..858cd47eca1d754eb4a2b149dbf822ff9602b2fd --- /dev/null +++ b/zsgdp/pipeline.py @@ -0,0 +1,342 @@ +"""End-to-end parse pipeline.""" + +from __future__ import annotations + +from pathlib import Path +from time import perf_counter +from typing import Any, Iterable + +from zsgdp.chunking import build_agentic_chunks +from zsgdp.config import load_config +from zsgdp.export import export_parsed_document +from zsgdp.gpu import collect_gpu_runtime_status, dry_run_gpu_tasks, plan_gpu_tasks +from zsgdp.logging_config import get_logger +from zsgdp.merge import merge_candidates +from zsgdp.parsers.base import ParseError, ParserUnavailableError +from zsgdp.parsers.registry import get_parser +from zsgdp.profiling import profile_document +from zsgdp.repair import run_repair_loop +from zsgdp.routing import RouteDecision, route_document +from zsgdp.schema import DocumentProfile, ParsedDocument, QualityReport +from zsgdp.verify import ( + candidate_metrics, + compute_parser_disagreement, + compute_repair_success, + failure_metrics, + verify_chunks, + verify_parse, +) + +logger = get_logger(__name__) + + +def parse_document( + input_path: str | Path, + output_dir: str | Path | None = None, + *, + config_path: str | Path | None = None, + selected_parsers: Iterable[str] | None = None, + config_overrides: dict[str, Any] | None = None, +) -> ParsedDocument: + """Profile, route, parse, merge, verify, repair, chunk, and optionally export a document.""" + + path = Path(input_path) + config = load_config(config_path, config_overrides) + if output_dir is not None: + runtime = config.setdefault("runtime", {}) + runtime["output_dir"] = str(Path(output_dir)) + parse_started = perf_counter() + gpu_runtime = collect_gpu_runtime_status(config).to_dict() + profile = profile_document(path) + logger.info( + "parse_start", + extra={ + "doc_id": profile.doc_id, + "source_path": str(path), + "file_type": profile.file_type, + "page_count": profile.page_count, + "device": gpu_runtime.get("device"), + }, + ) + route_decisions = route_document(profile, config) + parser_pages = _parser_pages_from_routes(route_decisions) + + selected = list(selected_parsers or []) + if selected: + all_pages = [page.page_num for page in profile.pages] + parser_pages = {name: all_pages for name in selected} + + candidates = [] + parser_failures: dict[str, str] = {} + parser_metrics: dict[str, dict] = {} + for parser_name, pages in parser_pages.items(): + if parser_name.startswith("vlm_"): + continue + try: + parser = get_parser(parser_name) + except KeyError as exc: + parser_failures[parser_name] = str(exc) + parser_metrics[parser_name] = failure_metrics(parser_name, profile, str(exc), elapsed_seconds=None) + continue + + if not parser.supports(profile.file_type): + parser_failures[parser_name] = f"{parser_name} does not support file type {profile.file_type}." + parser_metrics[parser_name] = failure_metrics(parser_name, profile, parser_failures[parser_name], elapsed_seconds=None) + continue + if not parser.available(): + parser_failures[parser_name] = f"{parser_name} is not installed or unavailable." + parser_metrics[parser_name] = failure_metrics(parser_name, profile, parser_failures[parser_name], elapsed_seconds=None) + continue + + started = perf_counter() + try: + candidate = parser.parse(path, profile, config, pages=pages) + except (ParserUnavailableError, ParseError, RuntimeError) as exc: + elapsed = perf_counter() - started + parser_failures[parser_name] = str(exc) + parser_metrics[parser_name] = failure_metrics(parser_name, profile, str(exc), elapsed_seconds=elapsed) + logger.warning( + "parser_failed", + extra={ + "doc_id": profile.doc_id, + "parser": parser_name, + "elapsed_seconds": round(elapsed, 3), + "error": str(exc), + }, + ) + else: + elapsed = perf_counter() - started + candidate.provenance["elapsed_seconds"] = elapsed + candidates.append(candidate) + parser_metrics[parser_name] = candidate_metrics(candidate, profile, elapsed_seconds=elapsed) + logger.info( + "parser_candidate", + extra={ + "doc_id": profile.doc_id, + "parser": parser_name, + "elapsed_seconds": round(elapsed, 3), + "element_count": len(candidate.elements), + "table_count": len(candidate.tables), + "figure_count": len(candidate.figures), + }, + ) + + if candidates: + parsed = merge_candidates(candidates, profile) + else: + parsed = _empty_parsed_document(profile) + parsed.quality_report.add_issue( + "no_parser_candidate", + "error", + "No selected parser produced a candidate parse.", + blocking=True, + metadata={"parser_failures": parser_failures}, + ) + + disagreement = compute_parser_disagreement(parsed.provenance.get("conflict_report"), parser_metrics) + parsed.provenance.update( + { + "profile": profile.to_dict(), + "routing_report": [decision.to_dict() for decision in route_decisions], + "parser_failures": parser_failures, + "parser_metrics": parser_metrics, + "candidates": _serialize_candidates(candidates), + "parser_disagreement": disagreement, + "config_deployment": config.get("deployment", {}), + "gpu_runtime": gpu_runtime, + } + ) + parsed.quality_report = verify_parse(profile, parsed, config) + parsed.quality_report.metrics["parser_disagreement_rate"] = disagreement["disagreement_rate"] + parsed = _run_iterative_repair(profile, parsed, config, gpu_runtime) + repair_success = compute_repair_success( + parsed.provenance.get("pre_repair_quality"), + parsed.quality_report.to_dict(), + parsed.provenance.get("repair_iterations"), + ) + parsed.provenance["repair_success"] = repair_success + parsed.quality_report.metrics["repair_resolution_rate"] = repair_success["repair_resolution_rate"] + parsed.quality_report.metrics["repair_regression_rate"] = repair_success["repair_regression_rate"] + parsed.quality_report.metrics["parser_disagreement_rate"] = disagreement["disagreement_rate"] + parsed.chunks = build_agentic_chunks(parsed, profile, config) + parsed.quality_report = verify_chunks(parsed, config) + parsed.provenance["gpu_tasks"] = plan_gpu_tasks(profile, parsed, config, route_decisions, gpu_runtime) + if parsed.provenance["gpu_tasks"] and config.get("gpu", {}).get("validate_tasks", True): + parsed.provenance["gpu_task_report"] = dry_run_gpu_tasks( + parsed.provenance["gpu_tasks"], + config=config, + runtime_status=gpu_runtime, + ) + + if output_dir is not None: + export_parsed_document( + parsed, + output_dir, + routing_report=[decision.to_dict() for decision in route_decisions], + profile=profile.to_dict(), + ) + + logger.info( + "parse_end", + extra={ + "doc_id": profile.doc_id, + "elapsed_seconds": round(perf_counter() - parse_started, 3), + "quality_score": round(parsed.quality_report.score, 3), + "blocking_failures": parsed.quality_report.has_blocking_failures, + "element_count": len(parsed.elements), + "table_count": len(parsed.tables), + "figure_count": len(parsed.figures), + "chunk_count": len(parsed.chunks), + "candidate_parsers": list(disagreement.get("successful_parsers") or []), + "repair_iterations": len(parsed.provenance.get("repair_iterations") or []), + }, + ) + return parsed + + +def _serialize_candidates(candidates: list) -> dict[str, dict[str, Any]]: + """Slim per-parser candidate snapshot for benchmark-time per-parser metrics. + + Only the attributes the GT-comparison extractors need are preserved + (bbox, type, page_num, markdown/html for tables, latex for formulas). + Candidates are keyed by parser_name; if the same parser ran twice, the + later candidate overwrites the earlier one (the merger doesn't support + parallel runs of the same parser today). + """ + + snapshot: dict[str, dict[str, Any]] = {} + for candidate in candidates: + snapshot[candidate.parser_name] = { + "parser_name": candidate.parser_name, + "confidence": candidate.confidence, + "elements": [ + { + "element_id": element.element_id, + "type": element.type, + "page_num": element.page_num, + "bbox": list(element.bbox) if element.bbox else None, + "text": element.text, + "markdown": element.markdown, + "html": element.html, + } + for element in candidate.elements + ], + "tables": [ + { + "table_id": table.table_id, + "page_nums": list(table.page_nums or []), + "bbox": [list(box) for box in table.bbox] if isinstance(table.bbox, list) else (list(table.bbox) if table.bbox else None), + "markdown": table.markdown, + "html": table.html, + } + for table in candidate.tables + ], + "figures": [ + { + "figure_id": figure.figure_id, + "page_num": figure.page_num, + "bbox": list(figure.bbox) if figure.bbox else None, + } + for figure in candidate.figures + ], + } + return snapshot + + +def _parser_pages_from_routes(decisions: list[RouteDecision]) -> dict[str, list[int]]: + parser_pages: dict[str, list[int]] = {} + for decision in decisions: + for expert in decision.experts: + if expert.startswith("vlm_"): + continue + parser_pages.setdefault(expert, []).append(decision.page_id) + return parser_pages + + +def _repair_changed_document(parsed: ParsedDocument) -> bool: + repair = parsed.provenance.get("repair", {}) + return bool(repair.get("actions")) + + +def _run_iterative_repair( + profile: DocumentProfile, + parsed: ParsedDocument, + config: dict[str, Any], + gpu_runtime: dict[str, Any], +) -> ParsedDocument: + repair_config = config.get("repair", {}) + if not repair_config.get("enabled", True): + return run_repair_loop(profile, parsed, config, runtime_status=gpu_runtime, iteration=0) + + max_iterations = max(1, int(repair_config.get("max_iterations", 3))) + accept_threshold = float(config.get("quality", {}).get("accept_threshold", 0.88)) + history: list[dict[str, Any]] = [] + + for iteration in range(1, max_iterations + 1): + if not _quality_needs_repair(parsed.quality_report, accept_threshold): + if not history: + parsed.provenance["repair"] = { + "enabled": True, + "status": "skipped_quality_accepted", + "actions": [], + "profile_page_count": profile.page_count, + } + break + + before_quality = parsed.quality_report.to_dict() + parsed = run_repair_loop( + profile, + parsed, + config, + runtime_status=gpu_runtime, + iteration=iteration, + ) + changed = _repair_changed_document(parsed) + if changed: + parsed.provenance.setdefault("pre_repair_quality", before_quality) + parsed.quality_report = QualityReport() + parsed.quality_report = verify_parse(profile, parsed, config) + + after_quality = parsed.quality_report.to_dict() + history.append( + { + "iteration": iteration, + "before_score": before_quality.get("score", 0.0), + "after_score": after_quality.get("score", 0.0), + "before_blocking_failures": sum(1 for issue in before_quality.get("issues", []) if issue.get("blocking")), + "after_blocking_failures": sum(1 for issue in after_quality.get("issues", []) if issue.get("blocking")), + "actions": parsed.provenance.get("repair", {}).get("actions", []), + "status": parsed.provenance.get("repair", {}).get("status"), + } + ) + + if not changed: + break + + parsed.provenance["repair_iterations"] = history + return parsed + + +def _quality_needs_repair(report: QualityReport, accept_threshold: float) -> bool: + repairable_issue_types = { + "empty_page", + "missing_text_coverage", + "low_ocr_confidence", + "reading_order_failure", + "invalid_table", + "missing_figure_caption", + "missing_figure_region", + } + issue_types = {issue.issue_type for issue in report.issues} + return report.score < accept_threshold or report.has_blocking_failures or bool(issue_types & repairable_issue_types) + + +def _empty_parsed_document(profile: DocumentProfile) -> ParsedDocument: + return ParsedDocument( + doc_id=profile.doc_id, + source_path=profile.source_path, + file_type=profile.file_type, + pages=[{"page_num": page.page_num, "profile": page.to_dict()} for page in profile.pages], + quality_report=QualityReport(score=0.0), + provenance={"candidate_parsers": []}, + ) diff --git a/zsgdp/preflight.py b/zsgdp/preflight.py new file mode 100644 index 0000000000000000000000000000000000000000..70e0e5a5bb766d96552c3e1e08e22dfa4e0167db --- /dev/null +++ b/zsgdp/preflight.py @@ -0,0 +1,164 @@ +"""Preflight runner — single-command deployment readiness check. + +Runs the local guards that should pass before pushing to a Hugging Face +Space. Each step exits non-zero on failure; the runner aggregates results +and returns 0 only when every selected step passed. + +Steps (all opt-out via flags): + +- unit: full unittest discover. Catches code regressions. +- regression: tests/regression/ snapshot fixtures. Catches behavior drift + in the parse pipeline. +- space_check: zsgdp.cli space-check. Verifies app.py, requirements.txt, + and README header are present and parseable. +- parsers: zsgdp.cli parsers. Confirms parser registry imports cleanly. +- benchmark: optional end-to-end smoke against tests/regression/fixtures/. + Off by default (adds 3-5s); enable with --benchmark. + +Each step's stdout/stderr is captured and surfaced on failure, suppressed +on success, so a clean preflight prints one summary line per step. +""" + +from __future__ import annotations + +import subprocess +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Sequence + + +@dataclass(slots=True) +class StepResult: + name: str + passed: bool + elapsed_seconds: float + output: str = "" + skipped: bool = False + skip_reason: str = "" + + +@dataclass(slots=True) +class PreflightResult: + steps: list[StepResult] = field(default_factory=list) + + @property + def passed(self) -> bool: + return all(step.passed or step.skipped for step in self.steps) + + @property + def failed_steps(self) -> list[StepResult]: + return [step for step in self.steps if not step.passed and not step.skipped] + + +def run_preflight( + *, + root: str | Path = ".", + skip_unit: bool = False, + skip_regression: bool = False, + skip_space_check: bool = False, + skip_parsers: bool = False, + run_benchmark: bool = False, + python: str | None = None, +) -> PreflightResult: + project_root = Path(root).resolve() + python_bin = python or sys.executable + result = PreflightResult() + + if not skip_unit: + result.steps.append(_run_step("unit", [python_bin, "-m", "unittest", "discover"], project_root)) + else: + result.steps.append(StepResult(name="unit", passed=True, elapsed_seconds=0.0, skipped=True, skip_reason="skipped")) + + if not skip_regression: + result.steps.append( + _run_step( + "regression", + [python_bin, "-m", "unittest", "tests.regression.test_regression", "-v"], + project_root, + ) + ) + else: + result.steps.append(StepResult(name="regression", passed=True, elapsed_seconds=0.0, skipped=True, skip_reason="skipped")) + + if not skip_space_check: + result.steps.append( + _run_step( + "space_check", + [python_bin, "-m", "zsgdp.cli", "space-check", "--root", str(project_root)], + project_root, + ) + ) + else: + result.steps.append(StepResult(name="space_check", passed=True, elapsed_seconds=0.0, skipped=True, skip_reason="skipped")) + + if not skip_parsers: + result.steps.append( + _run_step("parsers", [python_bin, "-m", "zsgdp.cli", "parsers"], project_root) + ) + else: + result.steps.append(StepResult(name="parsers", passed=True, elapsed_seconds=0.0, skipped=True, skip_reason="skipped")) + + if run_benchmark: + fixtures = project_root / "tests" / "regression" / "fixtures" + out_dir = project_root / "out" / "preflight_benchmark" + result.steps.append( + _run_step( + "benchmark", + [ + python_bin, + "-m", + "zsgdp.cli", + "benchmark", + "--input", + str(fixtures), + "--output", + str(out_dir), + ], + project_root, + ) + ) + + return result + + +def _run_step(name: str, command: Sequence[str], cwd: Path) -> StepResult: + started = time.perf_counter() + completed = subprocess.run( + list(command), + cwd=cwd, + capture_output=True, + text=True, + ) + elapsed = time.perf_counter() - started + output = (completed.stdout or "") + (completed.stderr or "") + return StepResult( + name=name, + passed=completed.returncode == 0, + elapsed_seconds=elapsed, + output=output, + ) + + +def format_summary(result: PreflightResult) -> str: + lines: list[str] = [] + for step in result.steps: + if step.skipped: + lines.append(f" [skip] {step.name}") + continue + status = "ok" if step.passed else "FAIL" + lines.append(f" [{status}] {step.name} ({step.elapsed_seconds:.2f}s)") + overall = "PASS" if result.passed else "FAIL" + lines.append(f"preflight: {overall}") + return "\n".join(lines) + + +def format_failures(result: PreflightResult) -> str: + if result.passed: + return "" + chunks: list[str] = [] + for step in result.failed_steps: + chunks.append(f"--- {step.name} output ---") + chunks.append(step.output.strip() or "(no output)") + return "\n".join(chunks) diff --git a/zsgdp/profiling/__init__.py b/zsgdp/profiling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2f745675028fa98b7102f242a6ee57aa211b69 --- /dev/null +++ b/zsgdp/profiling/__init__.py @@ -0,0 +1,4 @@ +from zsgdp.profiling.document_profile import profile_document +from zsgdp.profiling.page_render import CroppedRegion, RenderedPage, crop_pdf_region, render_pdf_pages + +__all__ = ["CroppedRegion", "RenderedPage", "crop_pdf_region", "profile_document", "render_pdf_pages"] diff --git a/zsgdp/profiling/document_profile.py b/zsgdp/profiling/document_profile.py new file mode 100644 index 0000000000000000000000000000000000000000..59cf9fb9ad5c16ce5c54092ac4c3ada6f0b1d5fb --- /dev/null +++ b/zsgdp/profiling/document_profile.py @@ -0,0 +1,82 @@ +"""Top-level document profiling entrypoint.""" + +from __future__ import annotations + +from pathlib import Path + +from zsgdp.profiling.heuristics import ( + aggregate_document_labels, + estimate_formula_density, + estimate_table_density, + estimate_text_quality, + labels_for_page, +) +from zsgdp.profiling.pdf_profile import profile_pdf +from zsgdp.schema import DocumentProfile, PageProfile +from zsgdp.utils import document_id_for_path, file_type_from_path + + +def profile_document(path: str | Path) -> DocumentProfile: + path_obj = Path(path) + if not path_obj.exists(): + raise FileNotFoundError(path_obj) + + doc_id = document_id_for_path(path_obj) + file_type = file_type_from_path(path_obj) + if file_type == "pdf": + return profile_pdf(path_obj, doc_id) + if file_type in {"text", "markdown", "html"}: + return _profile_text_like(path_obj, doc_id, file_type) + return _profile_unknown(path_obj, doc_id, file_type) + + +def _profile_text_like(path: Path, doc_id: str, file_type: str) -> DocumentProfile: + text = path.read_text(encoding="utf-8", errors="replace") + table_density, table_count = estimate_table_density(text) + formula_density = estimate_formula_density(text) + lines = [line for line in text.splitlines() if line.strip()] + char_count = len(text.strip()) + page = PageProfile( + page_num=1, + digital_text_chars=char_count, + text_block_count=len(lines), + avg_chars_per_text_block=(char_count / len(lines)) if lines else 0.0, + table_density=table_density, + table_candidate_count=table_count, + formula_density=formula_density, + digital_text_quality=estimate_text_quality(char_count, len(lines)), + scanned_score=0.0, + metadata={"line_count": len(lines)}, + ) + page.labels = labels_for_page(page) + pages = [page] + return DocumentProfile( + doc_id=doc_id, + source_path=str(path), + file_type=file_type, + page_count=1, + extension=path.suffix.lower(), + pages=pages, + labels=aggregate_document_labels(pages), + metadata={"profiler": "text_like"}, + ) + + +def _profile_unknown(path: Path, doc_id: str, file_type: str) -> DocumentProfile: + page = PageProfile( + page_num=1, + scanned_score=0.5 if file_type == "image" else 0.0, + digital_text_quality=0.0, + labels=["low_confidence"], + metadata={"warning": "No native profiler implemented for this file type yet."}, + ) + return DocumentProfile( + doc_id=doc_id, + source_path=str(path), + file_type=file_type, + page_count=1, + extension=path.suffix.lower(), + pages=[page], + labels=["low_confidence"], + metadata={"profiler": "unknown"}, + ) diff --git a/zsgdp/profiling/heuristics.py b/zsgdp/profiling/heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..7d10719466a2c6228bf566f8ac1d066948637c20 --- /dev/null +++ b/zsgdp/profiling/heuristics.py @@ -0,0 +1,67 @@ +"""Cheap profiling heuristics.""" + +from __future__ import annotations + +import re + +from zsgdp.schema import PageProfile +from zsgdp.utils import clamp + + +TABLE_ROW_RE = re.compile(r"(\s{2,}|\t|\|).*(\d|\w)") +FORMULA_SYMBOLS = set("=+-*/^_<>[]{}|") + + +def estimate_table_density(text: str) -> tuple[float, int]: + lines = [line for line in text.splitlines() if line.strip()] + if not lines: + return 0.0, 0 + table_like = [line for line in lines if TABLE_ROW_RE.search(line)] + markdown_table_lines = [line for line in lines if line.strip().startswith("|") and line.strip().endswith("|")] + count = len(set(table_like + markdown_table_lines)) + return clamp(count / max(len(lines), 1)), count + + +def estimate_formula_density(text: str) -> float: + compact = "".join(text.split()) + if not compact: + return 0.0 + symbol_count = sum(1 for char in compact if char in FORMULA_SYMBOLS) + latex_hits = len(re.findall(r"\\(?:frac|sum|int|sqrt|alpha|beta|gamma|theta|sigma)", text)) + return clamp((symbol_count + 3 * latex_hits) / max(len(compact), 1)) + + +def estimate_text_quality(char_count: int, block_count: int) -> float: + if char_count <= 0: + return 0.0 + block_bonus = min(block_count, 20) / 20 + char_score = clamp(char_count / 1200) + return clamp(0.8 * char_score + 0.2 * block_bonus) + + +def labels_for_page(page: PageProfile) -> list[str]: + labels: list[str] = [] + if page.digital_text_quality >= 0.65: + labels.append("digital_text") + if page.scanned_score >= 0.70: + labels.append("scanned_text") + if page.estimated_columns >= 2: + labels.append("multi_column") + if page.table_density >= 0.25: + labels.append("table_heavy") + if page.formula_density >= 0.15: + labels.append("formula_heavy") + if page.figure_density >= 0.20: + labels.append("figure_heavy") + if page.digital_text_quality < 0.25: + labels.append("low_confidence") + return labels + + +def aggregate_document_labels(pages: list[PageProfile]) -> list[str]: + counts: dict[str, int] = {} + for page in pages: + for label in page.labels: + counts[label] = counts.get(label, 0) + 1 + total = max(len(pages), 1) + return sorted(label for label, count in counts.items() if count / total >= 0.25) diff --git a/zsgdp/profiling/page_render.py b/zsgdp/profiling/page_render.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4e4748bf15792a716156afad7532888d164b11 --- /dev/null +++ b/zsgdp/profiling/page_render.py @@ -0,0 +1,135 @@ +"""PDF page rendering and region crop utilities.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + + +class PDFRenderDependencyError(RuntimeError): + """Raised when PyMuPDF is required for rendering but is not installed.""" + + +@dataclass(slots=True) +class RenderedPage: + page_num: int + image_path: str + width: float + height: float + dpi: int + scale: float + + def to_dict(self) -> dict[str, float | int | str]: + return { + "page_num": self.page_num, + "image_path": self.image_path, + "width": self.width, + "height": self.height, + "dpi": self.dpi, + "scale": self.scale, + } + + +@dataclass(slots=True) +class CroppedRegion: + page_num: int + image_path: str + bbox: tuple[float, float, float, float] + dpi: int + scale: float + + def to_dict(self) -> dict[str, float | int | str | tuple[float, float, float, float]]: + return { + "page_num": self.page_num, + "image_path": self.image_path, + "bbox": self.bbox, + "dpi": self.dpi, + "scale": self.scale, + } + + +def render_pdf_pages( + pdf_path: str | Path, + output_dir: str | Path, + *, + pages: Iterable[int] | None = None, + dpi: int = 150, + image_format: str = "png", +) -> list[RenderedPage]: + """Render selected one-indexed PDF pages to images.""" + + fitz = _require_fitz() + pdf_path = Path(pdf_path) + out = Path(output_dir) + out.mkdir(parents=True, exist_ok=True) + rendered: list[RenderedPage] = [] + selected_pages = set(pages or []) + scale = dpi / 72.0 + matrix = fitz.Matrix(scale, scale) + + with fitz.open(pdf_path) as pdf: + for index, page in enumerate(pdf, start=1): + if selected_pages and index not in selected_pages: + continue + output_path = out / f"page_{index:04d}.{image_format}" + pixmap = page.get_pixmap(matrix=matrix, alpha=False) + pixmap.save(output_path) + rendered.append( + RenderedPage( + page_num=index, + image_path=str(output_path), + width=float(page.rect.width), + height=float(page.rect.height), + dpi=dpi, + scale=scale, + ) + ) + return rendered + + +def crop_pdf_region( + pdf_path: str | Path, + page_num: int, + bbox: tuple[float, float, float, float], + output_path: str | Path, + *, + dpi: int = 150, + padding: float = 2.0, +) -> CroppedRegion: + """Crop a one-indexed page region expressed in PDF point coordinates.""" + + fitz = _require_fitz() + pdf_path = Path(pdf_path) + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + scale = dpi / 72.0 + + with fitz.open(pdf_path) as pdf: + page = pdf[page_num - 1] + clip = fitz.Rect(*bbox) + clip = fitz.Rect( + max(clip.x0 - padding, page.rect.x0), + max(clip.y0 - padding, page.rect.y0), + min(clip.x1 + padding, page.rect.x1), + min(clip.y1 + padding, page.rect.y1), + ) + pixmap = page.get_pixmap(matrix=fitz.Matrix(scale, scale), clip=clip, alpha=False) + pixmap.save(output_path) + return CroppedRegion( + page_num=page_num, + image_path=str(output_path), + bbox=(float(clip.x0), float(clip.y0), float(clip.x1), float(clip.y1)), + dpi=dpi, + scale=scale, + ) + + +def _require_fitz(): + try: + import fitz # type: ignore + except ImportError as exc: + raise PDFRenderDependencyError( + "PDF rendering requires PyMuPDF. Install with `python -m pip install -e '.[pdf]'`." + ) from exc + return fitz diff --git a/zsgdp/profiling/pdf_profile.py b/zsgdp/profiling/pdf_profile.py new file mode 100644 index 0000000000000000000000000000000000000000..dff54fc669eb78b6f276998738ccbb2d7ff9ce7b --- /dev/null +++ b/zsgdp/profiling/pdf_profile.py @@ -0,0 +1,156 @@ +"""PDF-specific document profiling.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from zsgdp.profiling.heuristics import ( + aggregate_document_labels, + estimate_formula_density, + estimate_table_density, + estimate_text_quality, + labels_for_page, +) +from zsgdp.schema import DocumentProfile, PageProfile + + +def profile_pdf(path: str | Path, doc_id: str) -> DocumentProfile: + path_obj = Path(path) + try: + import fitz # type: ignore + except ImportError: + return _fallback_pdf_profile(path_obj, doc_id) + + pages: list[PageProfile] = [] + metadata: dict[str, Any] = {"profiler": "pymupdf"} + with fitz.open(path_obj) as pdf: + metadata.update({key: value for key, value in (pdf.metadata or {}).items() if value}) + for index, page in enumerate(pdf, start=1): + text = page.get_text("text") or "" + blocks = [block for block in page.get_text("blocks") if len(block) >= 5 and str(block[4]).strip()] + table_density, table_count = estimate_table_density(text) + formula_density = estimate_formula_density(text) + image_infos = _image_infos(page) + image_area_ratio = _image_area_ratio(page, image_infos) + drawing_count = _drawing_count(page) + font_count = _font_count(page) + estimated_columns = _estimate_columns(blocks, page.rect.width) + char_count = len(text.strip()) + scanned_score = _scanned_score(char_count, image_area_ratio, len(image_infos)) + digital_quality = estimate_text_quality(char_count, len(blocks)) + + profile = PageProfile( + page_num=index, + digital_text_chars=char_count, + image_area_ratio=image_area_ratio, + num_images=len(image_infos), + num_drawings=drawing_count, + font_count=font_count, + avg_chars_per_text_block=(char_count / len(blocks)) if blocks else 0.0, + text_block_count=len(blocks), + estimated_columns=estimated_columns, + table_density=table_density, + table_candidate_count=table_count, + formula_density=formula_density, + figure_density=min(1.0, (len(image_infos) + drawing_count / 20) / 5), + scanned_score=scanned_score, + digital_text_quality=digital_quality, + page_rotation=int(page.rotation or 0), + metadata={ + "width": float(page.rect.width), + "height": float(page.rect.height), + }, + ) + profile.labels = labels_for_page(profile) + pages.append(profile) + + return DocumentProfile( + doc_id=doc_id, + source_path=str(path_obj), + file_type="pdf", + page_count=len(pages), + extension=path_obj.suffix.lower(), + pages=pages, + labels=aggregate_document_labels(pages), + metadata=metadata, + ) + + +def _fallback_pdf_profile(path: Path, doc_id: str) -> DocumentProfile: + data = path.read_bytes() + page_count = max(data.count(b"/Type /Page"), 1) + pages = [ + PageProfile( + page_num=page_num, + scanned_score=0.5, + digital_text_quality=0.0, + labels=["low_confidence"], + metadata={"profiler_warning": "PyMuPDF not installed; using byte-count fallback."}, + ) + for page_num in range(1, page_count + 1) + ] + return DocumentProfile( + doc_id=doc_id, + source_path=str(path), + file_type="pdf", + page_count=page_count, + extension=path.suffix.lower(), + pages=pages, + labels=aggregate_document_labels(pages), + metadata={"profiler": "fallback_pdf", "warning": "Install PyMuPDF for real PDF profiling."}, + ) + + +def _image_infos(page: Any) -> list[dict[str, Any]]: + try: + return list(page.get_image_info(xrefs=True)) + except Exception: + return [{"xref": image[0]} for image in page.get_images(full=True)] + + +def _image_area_ratio(page: Any, image_infos: list[dict[str, Any]]) -> float: + page_area = float(page.rect.width * page.rect.height) or 1.0 + total = 0.0 + for image in image_infos: + bbox = image.get("bbox") + if bbox and len(bbox) == 4: + total += max(float(bbox[2]) - float(bbox[0]), 0.0) * max(float(bbox[3]) - float(bbox[1]), 0.0) + return min(total / page_area, 1.0) + + +def _drawing_count(page: Any) -> int: + try: + return len(page.get_drawings()) + except Exception: + return 0 + + +def _font_count(page: Any) -> int: + try: + fonts = page.get_fonts(full=True) + except Exception: + return 0 + return len({font[3] for font in fonts if len(font) > 3}) + + +def _estimate_columns(blocks: list[tuple[Any, ...]], page_width: float) -> int: + if len(blocks) < 6: + return 1 + centers = sorted((float(block[0]) + float(block[2])) / 2 for block in blocks) + left = [center for center in centers if center < page_width * 0.45] + right = [center for center in centers if center > page_width * 0.55] + if len(left) >= 3 and len(right) >= 3: + return 2 + return 1 + + +def _scanned_score(char_count: int, image_area_ratio: float, image_count: int) -> float: + if char_count > 200: + text_penalty = 0.0 + elif char_count > 30: + text_penalty = 0.4 + else: + text_penalty = 0.8 + image_score = min(image_area_ratio + image_count * 0.15, 1.0) + return min(1.0, 0.65 * text_penalty + 0.35 * image_score) diff --git a/zsgdp/repair/__init__.py b/zsgdp/repair/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6653ef9eb09cbbd12a81abae89417430d136519 --- /dev/null +++ b/zsgdp/repair/__init__.py @@ -0,0 +1,3 @@ +from zsgdp.repair.controller import run_repair_loop + +__all__ = ["run_repair_loop"] diff --git a/zsgdp/repair/controller.py b/zsgdp/repair/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f428cc26e21e1ad311ef0bbc026f77716406b4 --- /dev/null +++ b/zsgdp/repair/controller.py @@ -0,0 +1,557 @@ +"""Repair-loop orchestration.""" + +from __future__ import annotations + +import json +import re +from typing import Any + +from zsgdp.gpu.runtime import collect_gpu_runtime_status +from zsgdp.gpu.tasks import make_gpu_task +from zsgdp.gpu.worker import GPUWorker +from zsgdp.logging_config import get_logger +from zsgdp.repair.header_footer import mark_repeated_headers_and_footers +from zsgdp.repair.reading_order_repair import repair_reading_order_elements +from zsgdp.repair.table_repair import repair_table +from zsgdp.schema import DocumentProfile, Element, ParsedDocument, TableObject + +logger = get_logger(__name__) + + +def run_repair_loop( + profile: DocumentProfile, + parsed: ParsedDocument, + config: dict, + *, + runtime_status: dict[str, Any] | None = None, + iteration: int = 1, +) -> ParsedDocument: + repair_config = config.get("repair", {}) + if not repair_config.get("enabled", True): + parsed.provenance["repair"] = {"enabled": False} + return parsed + + actions: list[dict] = [] + issue_types = {issue.issue_type for issue in parsed.quality_report.issues} + + if repair_config.get("reading_order_repair", True): + marked = mark_repeated_headers_and_footers(parsed.elements) + if marked: + actions.append({"action": "mark_repeated_headers_footers", "element_count": marked}) + if "reading_order_failure" in issue_types: + order_repair = repair_reading_order_elements(parsed.elements) + if order_repair["changed"]: + actions.append({"action": "repair_reading_order", **order_repair}) + + if repair_config.get("table_repair", True): + table_actions = _repair_tables(parsed) + if table_actions: + actions.extend(table_actions) + + if repair_config.get("figure_repair", True): + caption_count = _attach_nearby_figure_captions(parsed) + if caption_count: + actions.append({"action": "attach_figure_captions", "figure_count": caption_count}) + + gpu_escalation = _run_gpu_repair_escalations(profile, parsed, config, runtime_status=runtime_status) + if gpu_escalation.get("applied_actions"): + actions.extend(gpu_escalation["applied_actions"]) + + parsed.provenance["repair"] = { + "enabled": True, + "iteration": iteration, + "status": _repair_status(actions, gpu_escalation), + "actions": actions, + "blocking_failures_before_repair": [ + issue.to_dict() for issue in parsed.quality_report.issues if issue.blocking + ], + "available_escalations": ["vlm_table_repair", "vlm_figure_repair", "ocr_repair"], + "gpu_escalation": gpu_escalation, + "profile_page_count": profile.page_count, + } + logger.info( + "repair_iteration", + extra={ + "doc_id": parsed.doc_id, + "iteration": iteration, + "status": parsed.provenance["repair"]["status"], + "deterministic_action_count": len(actions), + "gpu_task_count": int(gpu_escalation.get("task_count", 0)), + "gpu_executed_count": int(gpu_escalation.get("executed_count", 0)), + "gpu_dry_run": bool(gpu_escalation.get("dry_run", True)), + }, + ) + return parsed + + +def _repair_tables(parsed: ParsedDocument) -> list[dict]: + actions: list[dict] = [] + elements_by_id = {element.element_id: element for element in parsed.elements} + + for table in parsed.tables: + action = _repair_table_object(table) + if action: + actions.append(action) + element_id = table.provenance.get("element_id") + element = elements_by_id.get(element_id) if element_id else None + if element is not None and table.markdown and element.markdown != table.markdown: + element.provenance.setdefault("repair_original_markdown", element.markdown) + element.markdown = table.markdown + actions.append({"action": "sync_table_element", "table_id": table.table_id, "element_id": element.element_id}) + + for element in parsed.elements: + if element.type == "table" and element.markdown: + repaired = repair_table(element.markdown) + repaired_markdown = repaired.get("markdown") + if repaired_markdown and repaired_markdown != element.markdown: + element.provenance.setdefault("repair_original_markdown", element.markdown) + element.markdown = repaired_markdown + actions.append({"action": "normalize_table_element", "element_id": element.element_id}) + + return actions + + +def _repair_table_object(table: TableObject) -> dict | None: + repaired = repair_table(table.markdown or "", table.provenance.get("crop_path")) + action: dict = {"action": "repair_table", "table_id": table.table_id, "changes": []} + + repaired_markdown = repaired.get("markdown") + if repaired_markdown and repaired_markdown != table.markdown: + table.provenance.setdefault("repair_original_markdown", table.markdown) + table.markdown = repaired_markdown + action["changes"].append("markdown") + + rendering = repaired.get("natural_language_rendering") + if rendering and rendering != table.natural_language_rendering: + table.natural_language_rendering = rendering + action["changes"].append("natural_language_rendering") + + return action if action["changes"] else None + + +def _run_gpu_repair_escalations( + profile: DocumentProfile, + parsed: ParsedDocument, + config: dict, + *, + runtime_status: dict[str, Any] | None = None, +) -> dict: + repair_config = config.get("repair", {}) + if not repair_config.get("gpu_escalation", True): + return {"enabled": False, "task_count": 0, "results": [], "applied_actions": []} + + runtime = runtime_status or collect_gpu_runtime_status(config).to_dict() + tasks = _plan_repair_gpu_tasks(profile, parsed, config, runtime) + if not tasks: + return {"enabled": True, "dry_run": True, "task_count": 0, "results": [], "applied_actions": []} + + execute = bool(repair_config.get("execute_gpu_escalations", False)) + worker = GPUWorker(config) + results = [worker.run(task, dry_run=not execute) for task in tasks] + applied_actions = _apply_gpu_repair_results(parsed, tasks, results) if execute else [] + return { + "enabled": True, + "dry_run": not execute, + "task_count": len(tasks), + "ready_count": sum(1 for result in results if result.get("readiness", {}).get("ready")), + "blocked_count": sum(1 for result in results if not result.get("readiness", {}).get("ready")), + "executed_count": sum(1 for result in results if result.get("status") == "executed"), + "failed_count": sum( + 1 + for result in results + if result.get("readiness", {}).get("ready") and result.get("status") not in {"ready_dry_run", "executed"} + ), + "tasks": tasks, + "results": results, + "applied_actions": applied_actions, + } + + +def _plan_repair_gpu_tasks( + profile: DocumentProfile, + parsed: ParsedDocument, + config: dict, + runtime: dict[str, Any], +) -> list[dict[str, Any]]: + repair_config = config.get("repair", {}) + budget = _repair_budget(config) + tasks: list[dict[str, Any]] = [] + planned_keys: set[tuple] = set() + + def add_task(key: tuple, task_type: str, page_nums: list[int], *, priority: int, reason: str, **kwargs: Any) -> None: + if key in planned_keys: + return + planned_keys.add(key) + tasks.append( + make_gpu_task( + f"gr{len(tasks) + 1}", + task_type, + parsed.doc_id, + page_nums, + priority=priority, + runtime=runtime, + config=config, + budget=budget, + reason=reason, + **kwargs, + ) + ) + + tables_by_id = {table.table_id: table for table in parsed.tables} + figures_by_id = {figure.figure_id: figure for figure in parsed.figures} + for issue in parsed.quality_report.issues: + if issue.issue_type == "invalid_table" and repair_config.get("table_repair", True): + table = tables_by_id.get(issue.region_id or "") + if table is None: + continue + add_task( + ("table_vlm_repair", table.table_id), + "table_vlm_repair", + table.page_nums or ([issue.page_num] if issue.page_num else [1]), + priority=95, + region_id=table.table_id, + image_path=table.provenance.get("crop_path"), + bbox=table.bbox, + reason=f"Verification flagged invalid table {table.table_id}.", + inputs={ + "markdown": table.markdown, + "natural_language_rendering": table.natural_language_rendering, + "quality_issue": issue.to_dict(), + }, + metadata={"repair_trigger": issue.issue_type, "source_parser": table.source_parser}, + ) + elif issue.issue_type in {"empty_page", "missing_text_coverage", "low_ocr_confidence"} and repair_config.get( + "ocr_repair", True + ): + page_num = issue.page_num or 1 + add_task( + ("ocr_page", page_num), + "ocr_page", + [page_num], + priority=90, + image_path=_rendered_page_path(parsed, page_num), + reason=f"Verification flagged OCR/text coverage issue on page {page_num}: {issue.issue_type}.", + inputs={"quality_issue": issue.to_dict(), "page_profile": _page_profile(profile, page_num)}, + metadata={"repair_trigger": issue.issue_type}, + ) + elif issue.issue_type == "reading_order_failure" and repair_config.get("reading_order_repair", True): + page_num = issue.page_num or 1 + add_task( + ("reading_order", page_num), + "vlm_route_repair", + [page_num], + priority=75, + image_path=_rendered_page_path(parsed, page_num), + reason=f"Verification flagged reading-order failure on page {page_num}.", + inputs={ + "expert": "reading_order_repair", + "quality_issue": issue.to_dict(), + "element_ids": [element.element_id for element in parsed.elements if element.page_num == page_num], + "page_markdown": _page_markdown(parsed, page_num), + }, + metadata={"repair_trigger": issue.issue_type, "repair_type": "reading_order"}, + ) + elif issue.issue_type in {"missing_figure_caption", "missing_figure_region"} and repair_config.get( + "figure_repair", True + ): + figure = figures_by_id.get(issue.region_id or "") + if figure is None or not figure.image_path: + continue + add_task( + ("figure_description", figure.figure_id), + "figure_description", + [figure.page_num], + priority=80, + region_id=figure.figure_id, + image_path=figure.image_path, + bbox=figure.bbox, + reason=f"Verification flagged figure issue {issue.issue_type} for {figure.figure_id}.", + inputs={"caption": figure.caption, "chart_data": figure.chart_data, "quality_issue": issue.to_dict()}, + metadata={"repair_trigger": issue.issue_type, "source_parser": figure.source_parser}, + ) + + max_calls = int(config.get("gpu", {}).get("max_vlm_calls_per_doc", 30)) + return sorted(tasks, key=lambda item: (-int(item.get("priority", 0)), str(item.get("task_id"))))[:max_calls] + + +def _apply_gpu_repair_results( + parsed: ParsedDocument, + tasks: list[dict[str, Any]], + results: list[dict[str, Any]], +) -> list[dict]: + tasks_by_id = {task["task_id"]: task for task in tasks} + actions: list[dict] = [] + for result in results: + if result.get("status") != "executed": + continue + task = tasks_by_id.get(result.get("task_id")) + if not task: + continue + task_type = task.get("task_type") + if task_type == "table_vlm_repair": + action = _apply_table_vlm_output(parsed, task, result) + if action: + actions.extend(action if isinstance(action, list) else [action]) + elif task_type == "figure_description": + action = _apply_figure_description_output(parsed, task, result) + if action: + actions.append(action) + elif task_type == "ocr_page": + action = _apply_ocr_page_output(parsed, task, result) + if action: + actions.append(action) + elif task_type == "vlm_route_repair" and task.get("metadata", {}).get("repair_type") == "reading_order": + action = _apply_reading_order_output(parsed, task, result) + if action: + actions.append(action) + return actions + + +def _apply_table_vlm_output(parsed: ParsedDocument, task: dict[str, Any], result: dict[str, Any]) -> list[dict] | None: + table_id = task.get("region_id") + table = next((candidate for candidate in parsed.tables if candidate.table_id == table_id), None) + if table is None: + return None + text = _result_text(result) + markdown = _extract_markdown_table(text) + if not markdown: + return None + + repaired = repair_table(markdown, task.get("image_path")) + actions: list[dict] = [] + changes: list[str] = [] + if repaired.get("markdown") and repaired["markdown"] != table.markdown: + table.provenance.setdefault("repair_original_markdown", table.markdown) + table.markdown = repaired["markdown"] + changes.append("markdown") + if repaired.get("natural_language_rendering") and repaired["natural_language_rendering"] != table.natural_language_rendering: + table.natural_language_rendering = repaired["natural_language_rendering"] + changes.append("natural_language_rendering") + if changes: + table.provenance["gpu_repair_task_id"] = task.get("task_id") + actions.append({"action": "apply_gpu_table_repair", "table_id": table.table_id, "changes": changes}) + sync_action = _sync_table_element(parsed, table) + if sync_action: + actions.append(sync_action) + return actions or None + + +def _apply_figure_description_output(parsed: ParsedDocument, task: dict[str, Any], result: dict[str, Any]) -> dict | None: + figure_id = task.get("region_id") + figure = next((candidate for candidate in parsed.figures if candidate.figure_id == figure_id), None) + if figure is None: + return None + description = _result_text(result).strip() + if not description or description == figure.vlm_description: + return None + figure.provenance.setdefault("repair_original_vlm_description", figure.vlm_description) + figure.vlm_description = description + figure.provenance["gpu_repair_task_id"] = task.get("task_id") + return {"action": "apply_gpu_figure_description", "figure_id": figure.figure_id} + + +def _apply_ocr_page_output(parsed: ParsedDocument, task: dict[str, Any], result: dict[str, Any]) -> dict | None: + text = _result_text(result).strip() + if not text: + return None + page_num = int((task.get("page_nums") or [1])[0]) + element_id = f"gpu_ocr_p{page_num}" + existing = next((element for element in parsed.elements if element.element_id == element_id), None) + if existing: + if existing.text == text: + return None + existing.provenance.setdefault("repair_original_text", existing.text) + existing.text = text + return {"action": "update_gpu_ocr_page_text", "page_num": page_num, "element_id": element_id} + + reading_order = max( + [element.reading_order or 0 for element in parsed.elements if element.page_num == page_num] or [0] + ) + 1 + parsed.elements.append( + Element( + element_id=element_id, + doc_id=parsed.doc_id, + page_num=page_num, + type="paragraph", + text=text, + reading_order=reading_order, + confidence=0.75, + source_parser="gpu_ocr", + provenance={"gpu_repair_task_id": task.get("task_id"), "repair_type": "ocr_page"}, + ) + ) + return {"action": "insert_gpu_ocr_page_text", "page_num": page_num, "element_id": element_id} + + +def _apply_reading_order_output(parsed: ParsedDocument, task: dict[str, Any], result: dict[str, Any]) -> dict | None: + text = _result_text(result) + ordered_ids = _parse_ordered_element_ids(text, set(task.get("inputs", {}).get("element_ids", []))) + if not ordered_ids: + return {"action": "record_gpu_reading_order_feedback", "page_nums": task.get("page_nums"), "task_id": task.get("task_id")} + + order_index = {element_id: index for index, element_id in enumerate(ordered_ids, start=1)} + changed = False + for element in parsed.elements: + if element.element_id not in order_index: + continue + if element.reading_order != order_index[element.element_id]: + element.provenance.setdefault("repair_original_reading_order", element.reading_order) + element.reading_order = order_index[element.element_id] + changed = True + if changed: + repair_reading_order_elements(parsed.elements) + return { + "action": "apply_gpu_reading_order", + "page_nums": task.get("page_nums"), + "element_ids": ordered_ids, + } + return None + + +def _sync_table_element(parsed: ParsedDocument, table: TableObject) -> dict | None: + element_id = table.provenance.get("element_id") + if not element_id or not table.markdown: + return None + element = next((candidate for candidate in parsed.elements if candidate.element_id == element_id), None) + if element is None or element.markdown == table.markdown: + return None + element.provenance.setdefault("repair_original_markdown", element.markdown) + element.markdown = table.markdown + return {"action": "sync_table_element", "table_id": table.table_id, "element_id": element.element_id} + + +def _attach_nearby_figure_captions(parsed: ParsedDocument) -> int: + captions_by_page: dict[int, list[Element]] = {} + for element in parsed.elements: + if element.type == "caption" and element.content().strip(): + captions_by_page.setdefault(element.page_num, []).append(element) + + attached = 0 + for figure in parsed.figures: + if figure.caption: + continue + caption = _best_caption_for_figure(captions_by_page.get(figure.page_num, [])) + if caption is None: + continue + figure.caption = caption.content().strip() + figure.provenance["caption_source_element_id"] = caption.element_id + attached += 1 + return attached + + +def _best_caption_for_figure(captions: list[Element]) -> Element | None: + figure_like = [caption for caption in captions if caption.content().strip().lower().startswith(("figure", "fig."))] + candidates = figure_like or captions + return min(candidates, key=lambda item: item.reading_order or 0) if candidates else None + + +def _repair_status(actions: list[dict], gpu_escalation: dict) -> str: + if actions and gpu_escalation.get("applied_actions"): + return "executed_deterministic_and_gpu" + if actions: + return "executed_deterministic" + if gpu_escalation.get("task_count") and gpu_escalation.get("dry_run"): + return "gpu_escalation_dry_run" + if gpu_escalation.get("task_count"): + return "gpu_escalation_executed_no_changes" + return "no_repairs_needed" + + +def _repair_budget(config: dict[str, Any]) -> dict[str, Any]: + gpu = config.get("gpu", {}) + repair = config.get("repair", {}) + return { + "max_gpu_seconds_per_doc": float(gpu.get("max_gpu_seconds_per_doc", 120)), + "max_vlm_calls_per_doc": int(gpu.get("max_vlm_calls_per_doc", 30)), + "max_repair_iterations": int(repair.get("max_iterations", 3)), + "max_batch_size": int(gpu.get("max_batch_size", 4)), + } + + +def _rendered_page_path(parsed: ParsedDocument, page_num: int) -> str | None: + for page in parsed.pages: + if int(page.get("page_num", 1)) != page_num: + continue + for parser_page in page.get("parser_pages", []): + rendered = parser_page.get("rendered_page") + if isinstance(rendered, dict) and rendered.get("image_path"): + return str(rendered["image_path"]) + rendered = page.get("rendered_page") + if isinstance(rendered, dict) and rendered.get("image_path"): + return str(rendered["image_path"]) + return None + + +def _page_profile(profile: DocumentProfile, page_num: int) -> dict | None: + for page in profile.pages: + if page.page_num == page_num: + return page.to_dict() + return None + + +def _page_markdown(parsed: ParsedDocument, page_num: int) -> str: + elements = sorted( + [element for element in parsed.elements if element.page_num == page_num], + key=lambda item: (item.reading_order is None, item.reading_order or 0, item.element_id), + ) + return "\n\n".join(element.content().strip() for element in elements if element.content().strip()) + + +def _result_text(result: dict[str, Any]) -> str: + output = result.get("output") if isinstance(result.get("output"), dict) else result + if not isinstance(output, dict): + return str(output or "") + for key in ("markdown", "vlm_description", "caption", "text", "answer", "content"): + if output.get(key): + return str(output[key]) + raw = output.get("raw_output") + if isinstance(raw, str): + return raw + return "" + + +def _extract_markdown_table(text: str) -> str | None: + if not text.strip(): + return None + fenced = re.findall(r"```(?:markdown|md)?\s*(.*?)```", text, flags=re.IGNORECASE | re.DOTALL) + candidates = fenced or [text] + for candidate in candidates: + pipe_lines = [line.strip() for line in candidate.splitlines() if line.strip().startswith("|")] + if len(pipe_lines) >= 2: + return "\n".join(pipe_lines) + return text.strip() if "|" in text else None + + +def _parse_ordered_element_ids(text: str, allowed_ids: set[str]) -> list[str]: + if not text.strip() or not allowed_ids: + return [] + parsed = _parse_json_order(text) + if parsed: + return _filter_order(parsed, allowed_ids) + mentions = re.findall(r"[A-Za-z][A-Za-z0-9_-]*", text) + return _filter_order(mentions, allowed_ids) + + +def _parse_json_order(text: str) -> list[str]: + try: + data = json.loads(text) + except json.JSONDecodeError: + return [] + if isinstance(data, list): + return [str(item) for item in data] + if isinstance(data, dict): + for key in ("element_ids", "reading_order", "order"): + value = data.get(key) + if isinstance(value, list): + return [str(item) for item in value] + return [] + + +def _filter_order(items: list[str], allowed_ids: set[str]) -> list[str]: + out: list[str] = [] + seen: set[str] = set() + for item in items: + if item in allowed_ids and item not in seen: + out.append(item) + seen.add(item) + return out diff --git a/zsgdp/repair/figure_repair.py b/zsgdp/repair/figure_repair.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae6193d771be7ae56e49fdaf23b5dc593721c6a --- /dev/null +++ b/zsgdp/repair/figure_repair.py @@ -0,0 +1,12 @@ +"""Figure repair extension point.""" + +from __future__ import annotations + + +def describe_figure(image_path: str | None, caption: str | None = None) -> dict[str, str | None]: + return { + "caption": caption, + "vlm_description": None, + "image_path": image_path, + "uncertainty": "Figure repair backend is not configured in the MVP.", + } diff --git a/zsgdp/repair/header_footer.py b/zsgdp/repair/header_footer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ebbd4023cc02130decd02b39cecb026d99e0543 --- /dev/null +++ b/zsgdp/repair/header_footer.py @@ -0,0 +1,20 @@ +"""Header/footer cleanup helpers.""" + +from __future__ import annotations + +from collections import Counter + +from zsgdp.schema import Element + + +def mark_repeated_headers_and_footers(elements: list[Element], min_repetitions: int = 3) -> int: + counts = Counter(element.content().strip() for element in elements if element.content().strip()) + marked = 0 + for element in elements: + content = element.content().strip() + if counts[content] >= min_repetitions and len(content) <= 120: + if element.provenance.get("noise_candidate") == "repeated_header_footer": + continue + element.provenance["noise_candidate"] = "repeated_header_footer" + marked += 1 + return marked diff --git a/zsgdp/repair/ocr_repair.py b/zsgdp/repair/ocr_repair.py new file mode 100644 index 0000000000000000000000000000000000000000..1b62e2f9a698a11ca131fba43fc56833e882ad65 --- /dev/null +++ b/zsgdp/repair/ocr_repair.py @@ -0,0 +1,10 @@ +"""OCR repair extension point.""" + +from __future__ import annotations + + +def repair_ocr_text(text: str) -> dict[str, str]: + return { + "text": text, + "uncertainty": "OCR repair backend is not configured in the MVP.", + } diff --git a/zsgdp/repair/prompts.py b/zsgdp/repair/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..45d635070726a8fa8b8b33b66d21ce9935485749 --- /dev/null +++ b/zsgdp/repair/prompts.py @@ -0,0 +1,17 @@ +"""Prompt templates for future VLM/OCR repair backends.""" + +PAGE_TO_MARKDOWN_REPAIR = """You are repairing a parsed document page. Convert the provided page image into faithful Markdown. +Preserve headings, paragraphs, reading order, tables, equations, figure captions, and lists. +Do not summarize. Do not invent missing content. Mark uncertain text as [UNCERTAIN: ...]. +Return only Markdown.""" + +TABLE_REPAIR = """You are repairing a table extraction. +Given the table image and candidate extraction, return valid HTML, Markdown, concise natural-language rendering, and uncertainty notes. +Do not invent values. Preserve units, footnotes, headers, and merged-cell meaning.""" + +FIGURE_DESCRIPTION = """Describe this figure for document understanding. +Extract figure type, title/caption, axes and units if chart, legend labels, main trend, and visible numeric values. +Mark approximations explicitly.""" + +READING_ORDER_REPAIR = """Given page element boxes and the page image, produce the correct reading order. +Return element IDs only in order. Keep captions near their figures/tables.""" diff --git a/zsgdp/repair/reading_order_repair.py b/zsgdp/repair/reading_order_repair.py new file mode 100644 index 0000000000000000000000000000000000000000..0774ce33709ccfe3f818027291f5a62aab8f0a10 --- /dev/null +++ b/zsgdp/repair/reading_order_repair.py @@ -0,0 +1,44 @@ +"""Reading-order repair extension point.""" + +from __future__ import annotations + +from zsgdp.schema import Element + + +def repair_reading_order(element_ids: list[str]) -> list[str]: + return list(element_ids) + + +def repair_reading_order_elements(elements: list[Element]) -> dict: + """Reorder elements by page and recovered order, then normalize order ids.""" + + before = [element.element_id for element in elements] + sorted_elements = sorted(elements, key=_element_sort_key) + current_page: int | None = None + next_order = 1 + changed = before != [element.element_id for element in sorted_elements] + + for element in sorted_elements: + if element.page_num != current_page: + current_page = element.page_num + next_order = 1 + if element.reading_order != next_order: + changed = True + element.provenance.setdefault("repair_original_reading_order", element.reading_order) + element.reading_order = next_order + next_order += 1 + + if changed: + elements[:] = sorted_elements + + return { + "changed": changed, + "before_element_ids": before, + "after_element_ids": [element.element_id for element in elements], + } + + +def _element_sort_key(element: Element) -> tuple: + bbox = element.bbox or (0.0, 0.0, 0.0, 0.0) + reading_order = element.reading_order if element.reading_order is not None else 1_000_000 + return (element.page_num, reading_order, bbox[1], bbox[0], element.element_id) diff --git a/zsgdp/repair/table_repair.py b/zsgdp/repair/table_repair.py new file mode 100644 index 0000000000000000000000000000000000000000..d91e1c6cceaaaaa7d6cdf1cc578946f72c2f912d --- /dev/null +++ b/zsgdp/repair/table_repair.py @@ -0,0 +1,58 @@ +"""Table repair extension point.""" + +from __future__ import annotations + +import re + +from zsgdp.normalize.markdown import normalize_markdown_table + + +def repair_table(markdown: str, table_image_path: str | None = None) -> dict[str, str | None]: + normalized = normalize_markdown_table(markdown) if markdown else "" + return { + "html": None, + "markdown": normalized or markdown, + "natural_language_rendering": table_to_natural_language(normalized or markdown), + "uncertainty": None, + "table_image_path": table_image_path, + } + + +def table_to_natural_language(markdown: str | None) -> str | None: + rows = _markdown_rows(markdown or "") + if len(rows) < 2: + return None + + headers = rows[0] + data_rows = rows[1:] + if not headers or not data_rows: + return None + + row_descriptions: list[str] = [] + for row in data_rows: + label = row[0] if row else "row" + values = [] + for index, value in enumerate(row[1:], start=1): + header = headers[index] if index < len(headers) and headers[index] else f"column {index + 1}" + values.append(f"{header}={value}" if value else f"{header}=empty") + row_descriptions.append(f"{label}: " + ", ".join(values)) + + column_text = ", ".join(header or f"column {index + 1}" for index, header in enumerate(headers)) + return f"Table with columns {column_text}. Rows: " + "; ".join(row_descriptions) + "." + + +def _markdown_rows(markdown: str) -> list[list[str]]: + rows: list[list[str]] = [] + for line in markdown.splitlines(): + stripped = line.strip() + if not stripped.startswith("|") or not stripped.endswith("|"): + continue + cells = [_clean_cell(cell) for cell in stripped.strip("|").split("|")] + if cells and all(re.fullmatch(r":?-{3,}:?", cell) for cell in cells if cell): + continue + rows.append(cells) + return rows + + +def _clean_cell(cell: str) -> str: + return " ".join(cell.strip().split()) diff --git a/zsgdp/routing/__init__.py b/zsgdp/routing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..313416ae6370cc402168861beb8b0cdbc5d0447d --- /dev/null +++ b/zsgdp/routing/__init__.py @@ -0,0 +1,3 @@ +from zsgdp.routing.router import RouteDecision, route_document, route_page + +__all__ = ["RouteDecision", "route_document", "route_page"] diff --git a/zsgdp/routing/budgets.py b/zsgdp/routing/budgets.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0c808e06cf2fb247ceffe128043cb1ba512a7f --- /dev/null +++ b/zsgdp/routing/budgets.py @@ -0,0 +1,21 @@ +"""Parsing and repair budget schema.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(slots=True) +class Budget: + max_gpu_seconds: float = 60.0 + max_vlm_calls: int = 20 + max_repair_iterations: int = 3 + allow_expensive_models: bool = False + + def to_dict(self) -> dict[str, float | int | bool]: + return { + "max_gpu_seconds": self.max_gpu_seconds, + "max_vlm_calls": self.max_vlm_calls, + "max_repair_iterations": self.max_repair_iterations, + "allow_expensive_models": self.allow_expensive_models, + } diff --git a/zsgdp/routing/router.py b/zsgdp/routing/router.py new file mode 100644 index 0000000000000000000000000000000000000000..94602e548b28d952a546a5c22b3ae4428cf762bf --- /dev/null +++ b/zsgdp/routing/router.py @@ -0,0 +1,120 @@ +"""Deterministic page router.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from zsgdp.routing.budgets import Budget +from zsgdp.schema import DocumentProfile, PageProfile +from zsgdp.utils import to_plain_data + + +@dataclass(slots=True) +class RouteDecision: + page_id: int + experts: list[str] + reason: str + budget: Budget + labels: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + data = to_plain_data(self) + data["budget"] = self.budget.to_dict() + return data + + +def route_document(profile: DocumentProfile, config: dict[str, Any]) -> list[RouteDecision]: + return [route_page(page, profile, config) for page in profile.pages] + + +def route_page(page: PageProfile, profile: DocumentProfile, config: dict[str, Any]) -> RouteDecision: + routing = config.get("routing", {}) + gpu = config.get("gpu", {}) + experts: list[str] = [] + reasons: list[str] = [] + + if profile.file_type in {"text", "markdown", "html"}: + experts.append("text") + reasons.append("text-like document with native text available") + elif profile.file_type == "pdf": + if page.scanned_score > routing.get("scanned_text_threshold", 0.40): + experts.extend(["olmocr", "paddleocr", "pymupdf"]) + reasons.append("page appears scanned or has weak text layer") + + if page.table_density > routing.get("table_density_threshold", 0.25): + experts.extend(["docling", "marker", "pymupdf", "paddleocr"]) + reasons.append("table-heavy page") + + if page.formula_density > routing.get("formula_density_threshold", 0.15): + experts.extend(["mineru", "marker", "pymupdf"]) + reasons.append("formula-heavy page") + + if page.figure_density > routing.get("figure_density_threshold", 0.20): + experts.extend(["docling", "marker", "vlm_figure_repair", "pymupdf"]) + reasons.append("figure-heavy page") + + if page.digital_text_quality > 0.65: + experts.extend(["docling", "marker", "pymupdf"]) + reasons.append("digital text quality is strong") + + if not experts: + experts.extend(["pymupdf", "docling", "marker"]) + reasons.append("default PDF route") + elif profile.file_type == "image": + experts.extend(["olmocr", "paddleocr"]) + reasons.append("image document requires OCR/VLM parsing") + else: + experts.extend(["unstructured", "text"]) + reasons.append(f"default route for {profile.file_type}") + + experts = _filter_enabled(_dedupe(experts), config, fallback=_fallback_parser(profile.file_type)) + budget = Budget( + max_gpu_seconds=float(gpu.get("max_gpu_seconds_per_doc", 120)), + max_vlm_calls=int(gpu.get("max_vlm_calls_per_doc", 30)), + max_repair_iterations=int(config.get("repair", {}).get("max_iterations", 3)), + allow_expensive_models=bool(gpu.get("allow_expensive_models", False)), + ) + return RouteDecision( + page_id=page.page_num, + experts=experts, + reason="; ".join(reasons) if reasons else "default route", + budget=budget, + labels=list(page.labels), + metadata={ + "file_type": profile.file_type, + "deployment_target": config.get("deployment", {}).get("target"), + "gpu_models_target": config.get("deployment", {}).get("gpu_models_target"), + }, + ) + + +def _dedupe(items: list[str]) -> list[str]: + seen: set[str] = set() + deduped: list[str] = [] + for item in items: + if item not in seen: + deduped.append(item) + seen.add(item) + return deduped + + +def _filter_enabled(experts: list[str], config: dict[str, Any], fallback: str) -> list[str]: + parser_config = config.get("parsers", {}) + enabled = [ + expert + for expert in experts + if expert.startswith("vlm_") or parser_config.get(expert, {}).get("enabled", False) + ] + if fallback and fallback not in enabled and parser_config.get(fallback, {}).get("enabled", False): + enabled.append(fallback) + return enabled or ([fallback] if fallback else []) + + +def _fallback_parser(file_type: str) -> str: + if file_type in {"text", "markdown", "html"}: + return "text" + if file_type == "pdf": + return "pymupdf" + return "text" diff --git a/zsgdp/schema/__init__.py b/zsgdp/schema/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a57795b6e44dfebb2f7c14c060749ce5712920fc --- /dev/null +++ b/zsgdp/schema/__init__.py @@ -0,0 +1,26 @@ +from zsgdp.schema.document import ( + SCHEMA_VERSION, + Chunk, + Element, + FigureObject, + ParsedDocument, + ParseCandidate, + QualityIssue, + QualityReport, + TableObject, +) +from zsgdp.schema.profiles import DocumentProfile, PageProfile + +__all__ = [ + "Chunk", + "DocumentProfile", + "Element", + "FigureObject", + "PageProfile", + "ParsedDocument", + "ParseCandidate", + "QualityIssue", + "QualityReport", + "SCHEMA_VERSION", + "TableObject", +] diff --git a/zsgdp/schema/document.py b/zsgdp/schema/document.py new file mode 100644 index 0000000000000000000000000000000000000000..39d4c998bd5103d74fa002cc9d2e308c01b1d02c --- /dev/null +++ b/zsgdp/schema/document.py @@ -0,0 +1,206 @@ +"""Canonical parse schema.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from zsgdp.utils import to_plain_data + +# Bumped any time the on-disk shape of ParsedDocument, Element, TableObject, +# FigureObject, Chunk, or QualityReport changes in a way that older readers +# would not understand. Increment by 1 for additive changes (new optional +# fields), bump the major (e.g. 2.0 -> 3.0) for breaking renames or removals. +SCHEMA_VERSION = "1.0" + +BBox = tuple[float, float, float, float] + + +@dataclass(slots=True) +class Element: + element_id: str + doc_id: str + page_num: int + type: str + text: str | None = None + markdown: str | None = None + html: str | None = None + bbox: BBox | None = None + reading_order: int | None = None + confidence: float | None = None + source_parser: str = "unknown" + provenance: dict[str, Any] = field(default_factory=dict) + parent_id: str | None = None + + def content(self) -> str: + return self.markdown or self.text or self.html or "" + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +@dataclass(slots=True) +class TableObject: + table_id: str + page_nums: list[int] + bbox: list[BBox] | None = None + html: str | None = None + markdown: str | None = None + dataframe_json: dict[str, Any] | None = None + natural_language_rendering: str | None = None + caption: str | None = None + footnotes: list[str] = field(default_factory=list) + confidence: float = 0.0 + source_parser: str = "unknown" + provenance: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +@dataclass(slots=True) +class FigureObject: + figure_id: str + page_num: int + bbox: BBox | None = None + image_path: str | None = None + caption: str | None = None + vlm_description: str | None = None + chart_data: dict[str, Any] | None = None + confidence: float = 0.0 + source_parser: str = "unknown" + provenance: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +@dataclass(slots=True) +class Chunk: + chunk_id: str + doc_id: str + page_start: int + page_end: int + section_path: list[str] + content_type: str + text: str + element_ids: list[str] = field(default_factory=list) + table_ids: list[str] = field(default_factory=list) + figure_ids: list[str] = field(default_factory=list) + parent_chunk_id: str | None = None + child_chunk_ids: list[str] = field(default_factory=list) + strategy: str = "recursive_structure" + boundary_reason: str = "structure" + token_count: int = 0 + source_parser: str = "unknown" + quality_score: float = 0.0 + requires_visual_context: bool = False + context_prefix: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +@dataclass(slots=True) +class QualityIssue: + issue_type: str + severity: str + message: str + page_num: int | None = None + element_id: str | None = None + region_id: str | None = None + blocking: bool = False + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +@dataclass(slots=True) +class QualityReport: + score: float = 1.0 + issues: list[QualityIssue] = field(default_factory=list) + metrics: dict[str, Any] = field(default_factory=dict) + + @property + def has_blocking_failures(self) -> bool: + return any(issue.blocking for issue in self.issues) + + def add_issue( + self, + issue_type: str, + severity: str, + message: str, + *, + page_num: int | None = None, + element_id: str | None = None, + region_id: str | None = None, + blocking: bool = False, + metadata: dict[str, Any] | None = None, + ) -> None: + self.issues.append( + QualityIssue( + issue_type=issue_type, + severity=severity, + message=message, + page_num=page_num, + element_id=element_id, + region_id=region_id, + blocking=blocking, + metadata=metadata or {}, + ) + ) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +@dataclass(slots=True) +class ParsedDocument: + doc_id: str + source_path: str + file_type: str + schema_version: str = SCHEMA_VERSION + pages: list[dict[str, Any]] = field(default_factory=list) + elements: list[Element] = field(default_factory=list) + tables: list[TableObject] = field(default_factory=list) + figures: list[FigureObject] = field(default_factory=list) + chunks: list[Chunk] = field(default_factory=list) + quality_report: QualityReport = field(default_factory=QualityReport) + provenance: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + def to_markdown(self) -> str: + lines: list[str] = [] + current_page: int | None = None + for element in sorted(self.elements, key=lambda item: (item.page_num, item.reading_order or 0)): + if element.page_num != current_page: + current_page = element.page_num + if lines: + lines.append("") + lines.append(f"") + content = element.content().strip() + if content: + lines.append(content) + lines.append("") + return "\n".join(lines).strip() + ("\n" if lines else "") + + +@dataclass(slots=True) +class ParseCandidate: + parser_name: str + doc_id: str + source_path: str + file_type: str + pages: list[dict[str, Any]] = field(default_factory=list) + elements: list[Element] = field(default_factory=list) + tables: list[TableObject] = field(default_factory=list) + figures: list[FigureObject] = field(default_factory=list) + confidence: float = 0.0 + provenance: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) diff --git a/zsgdp/schema/profiles.py b/zsgdp/schema/profiles.py new file mode 100644 index 0000000000000000000000000000000000000000..8317720cfcfb912893f048135c396c846e500caa --- /dev/null +++ b/zsgdp/schema/profiles.py @@ -0,0 +1,48 @@ +"""Document profiling schema.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from zsgdp.utils import to_plain_data + + +@dataclass(slots=True) +class PageProfile: + page_num: int + digital_text_chars: int = 0 + image_area_ratio: float = 0.0 + num_images: int = 0 + num_drawings: int = 0 + font_count: int = 0 + avg_chars_per_text_block: float = 0.0 + text_block_count: int = 0 + estimated_columns: int = 1 + table_density: float = 0.0 + table_candidate_count: int = 0 + formula_density: float = 0.0 + figure_density: float = 0.0 + scanned_score: float = 0.0 + digital_text_quality: float = 0.0 + page_rotation: int = 0 + labels: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) + + +@dataclass(slots=True) +class DocumentProfile: + doc_id: str + source_path: str + file_type: str + page_count: int + extension: str + pages: list[PageProfile] = field(default_factory=list) + labels: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return to_plain_data(self) diff --git a/zsgdp/utils.py b/zsgdp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b01609b59bcda35d7ec966c9ea76ae71efb29a --- /dev/null +++ b/zsgdp/utils.py @@ -0,0 +1,72 @@ +"""Small shared utilities.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import asdict, is_dataclass +from pathlib import Path +from typing import Any, Iterable + + +def document_id_for_path(path: str | Path) -> str: + path_obj = Path(path) + stat = path_obj.stat() + seed = f"{path_obj.resolve()}:{stat.st_size}:{int(stat.st_mtime)}" + return hashlib.sha1(seed.encode("utf-8")).hexdigest()[:16] + + +def file_type_from_path(path: str | Path) -> str: + suffix = Path(path).suffix.lower().lstrip(".") + if suffix == "pdf": + return "pdf" + if suffix in {"docx", "doc"}: + return "docx" + if suffix in {"pptx", "ppt"}: + return "pptx" + if suffix in {"xlsx", "xls", "csv"}: + return "xlsx" + if suffix in {"html", "htm"}: + return "html" + if suffix in {"png", "jpg", "jpeg", "tiff", "tif", "bmp", "webp"}: + return "image" + if suffix == "epub": + return "epub" + if suffix in {"md", "markdown"}: + return "markdown" + if suffix in {"txt", "text"}: + return "text" + return suffix or "unknown" + + +def to_plain_data(value: Any) -> Any: + if is_dataclass(value): + return {key: to_plain_data(item) for key, item in asdict(value).items()} + if isinstance(value, dict): + return {str(key): to_plain_data(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [to_plain_data(item) for item in value] + if isinstance(value, Path): + return str(value) + return value + + +def dumps_json(value: Any, *, indent: int = 2) -> str: + return json.dumps(to_plain_data(value), indent=indent, ensure_ascii=False, sort_keys=True) + + +def write_json(path: str | Path, value: Any) -> None: + Path(path).write_text(dumps_json(value) + "\n", encoding="utf-8") + + +def write_jsonl(path: str | Path, records: Iterable[Any]) -> None: + lines = [json.dumps(to_plain_data(record), ensure_ascii=False, sort_keys=True) for record in records] + Path(path).write_text("\n".join(lines) + ("\n" if lines else ""), encoding="utf-8") + + +def normalize_whitespace(text: str) -> str: + return " ".join(text.split()) + + +def clamp(value: float, low: float = 0.0, high: float = 1.0) -> float: + return max(low, min(high, value)) diff --git a/zsgdp/verify/__init__.py b/zsgdp/verify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f12a2697f722a83f0e9d46391aa46adb352da6 --- /dev/null +++ b/zsgdp/verify/__init__.py @@ -0,0 +1,14 @@ +from zsgdp.verify.chunk_readiness import verify_chunks +from zsgdp.verify.parser_disagreement import compute_parser_disagreement +from zsgdp.verify.parser_metrics import candidate_metrics, failure_metrics +from zsgdp.verify.quality_report import verify_parse +from zsgdp.verify.repair_success import compute_repair_success + +__all__ = [ + "candidate_metrics", + "compute_parser_disagreement", + "compute_repair_success", + "failure_metrics", + "verify_chunks", + "verify_parse", +] diff --git a/zsgdp/verify/chunk_readiness.py b/zsgdp/verify/chunk_readiness.py new file mode 100644 index 0000000000000000000000000000000000000000..a52eb5c657bf5845916199eda05a3324b691069f --- /dev/null +++ b/zsgdp/verify/chunk_readiness.py @@ -0,0 +1,160 @@ +"""Chunk-readiness checks.""" + +from __future__ import annotations + +from statistics import mean + +from zsgdp.schema import ParsedDocument, QualityReport + + +def verify_chunks(parsed: ParsedDocument, config: dict | None = None) -> QualityReport: + """Append chunk-readiness metrics and issues to an existing quality report.""" + + config = config or {} + chunking = config.get("chunking", {}) + report = parsed.quality_report + if not chunking.get("enabled", True): + report.metrics["chunking_enabled"] = False + return report + + chunks = parsed.chunks + target_tokens = int(chunking.get("target_tokens", 512)) + max_reasonable_tokens = max(target_tokens * 2, target_tokens + 256) + chunk_ids = [chunk.chunk_id for chunk in chunks] + duplicate_ids = sorted({chunk_id for chunk_id in chunk_ids if chunk_ids.count(chunk_id) > 1}) + parent_ids = {chunk.chunk_id for chunk in chunks if chunk.content_type == "parent"} + child_chunks = [chunk for chunk in chunks if chunk.parent_chunk_id] + orphan_children = [chunk.chunk_id for chunk in child_chunks if chunk.parent_chunk_id not in parent_ids] + missing_page_provenance = [chunk.chunk_id for chunk in chunks if chunk.page_start <= 0 or chunk.page_end < chunk.page_start] + missing_source = [chunk.chunk_id for chunk in chunks if not chunk.source_parser or chunk.source_parser == "unknown"] + oversized = [chunk.chunk_id for chunk in chunks if chunk.token_count > max_reasonable_tokens and chunk.content_type != "parent"] + empty = [chunk.chunk_id for chunk in chunks if not chunk.text.strip()] + table_chunk_ids = {table_id for chunk in chunks for table_id in chunk.table_ids} + figure_chunk_ids = {figure_id for chunk in chunks for figure_id in chunk.figure_ids} + parsed_table_ids = {table.table_id for table in parsed.tables} + parsed_figure_ids = {figure.figure_id for figure in parsed.figures} + uncovered_tables = sorted(parsed_table_ids - table_chunk_ids) + uncovered_figures = sorted(parsed_figure_ids - figure_chunk_ids) + + token_counts = [chunk.token_count for chunk in chunks] + report.metrics.update( + { + "chunking_enabled": True, + "chunk_count": len(chunks), + "parent_chunk_count": len(parent_ids), + "child_chunk_count": len(child_chunks), + "chunk_strategy_counts": _strategy_counts(parsed), + "avg_chunk_tokens": mean(token_counts) if token_counts else 0.0, + "max_chunk_tokens": max(token_counts) if token_counts else 0, + "table_chunk_coverage": _coverage(len(parsed_table_ids), len(table_chunk_ids)), + "figure_chunk_coverage": _coverage(len(parsed_figure_ids), len(figure_chunk_ids)), + } + ) + + if not chunks and parsed.elements: + report.add_issue( + "missing_chunks", + "warning", + "Parsed document has elements but no chunks.", + blocking=False, + ) + if duplicate_ids: + report.add_issue( + "duplicate_chunk_ids", + "error", + "Chunk IDs must be unique.", + blocking=True, + metadata={"chunk_ids": duplicate_ids}, + ) + if orphan_children: + report.add_issue( + "orphan_child_chunks", + "error", + "Child chunks reference missing parent chunks.", + blocking=True, + metadata={"chunk_ids": orphan_children}, + ) + if missing_page_provenance: + report.add_issue( + "chunks_without_page_provenance", + "warning", + "Some chunks have invalid page provenance.", + blocking=False, + metadata={"chunk_ids": missing_page_provenance}, + ) + if missing_source: + report.add_issue( + "chunks_without_source_parser", + "warning", + "Some chunks are missing source parser provenance.", + blocking=False, + metadata={"chunk_ids": missing_source}, + ) + if oversized: + report.add_issue( + "oversized_chunks", + "warning", + "Some child chunks exceed the configured token budget.", + blocking=False, + metadata={"chunk_ids": oversized, "max_reasonable_tokens": max_reasonable_tokens}, + ) + if empty: + report.add_issue( + "empty_chunks", + "warning", + "Some chunks are empty.", + blocking=False, + metadata={"chunk_ids": empty}, + ) + if uncovered_tables: + report.add_issue( + "tables_without_chunks", + "warning", + "Some table objects do not have table chunks.", + blocking=False, + metadata={"table_ids": uncovered_tables}, + ) + if uncovered_figures: + report.add_issue( + "figures_without_chunks", + "warning", + "Some figure objects do not have figure chunks.", + blocking=False, + metadata={"figure_ids": uncovered_figures}, + ) + + report.score = max(0.0, min(1.0, report.score - _chunk_penalty(report))) + return report + + +def _strategy_counts(parsed: ParsedDocument) -> dict[str, int]: + counts: dict[str, int] = {} + for chunk in parsed.chunks: + counts[chunk.strategy] = counts.get(chunk.strategy, 0) + 1 + return counts + + +def _coverage(total: int, covered: int) -> float: + if total == 0: + return 1.0 + return min(covered / total, 1.0) + + +def _chunk_penalty(report: QualityReport) -> float: + chunk_issue_types = { + "missing_chunks", + "duplicate_chunk_ids", + "orphan_child_chunks", + "chunks_without_page_provenance", + "chunks_without_source_parser", + "oversized_chunks", + "empty_chunks", + "tables_without_chunks", + "figures_without_chunks", + } + penalty = 0.0 + for issue in report.issues: + if issue.issue_type not in chunk_issue_types: + continue + penalty += 0.10 if issue.severity == "error" else 0.03 + return min(penalty, 0.4) diff --git a/zsgdp/verify/coverage.py b/zsgdp/verify/coverage.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb743c60e8ec34658a560ca93c2079d18cc13e9 --- /dev/null +++ b/zsgdp/verify/coverage.py @@ -0,0 +1,15 @@ +"""Coverage checks.""" + +from __future__ import annotations + +from zsgdp.schema import PageProfile, ParsedDocument + + +def page_text_chars(parsed: ParsedDocument, page_num: int) -> int: + return sum(len(element.content()) for element in parsed.elements if element.page_num == page_num) + + +def coverage_ratio(page: PageProfile, parsed: ParsedDocument) -> float: + expected = max(page.digital_text_chars, 1) + observed = page_text_chars(parsed, page.page_num) + return min(observed / expected, 1.0) diff --git a/zsgdp/verify/figure_quality.py b/zsgdp/verify/figure_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5b437f8a16c34f61c79639cde4e4e637f6ce7f --- /dev/null +++ b/zsgdp/verify/figure_quality.py @@ -0,0 +1,66 @@ +"""Figure quality checks.""" + +from __future__ import annotations + +from zsgdp.schema import ParsedDocument, QualityReport + + +def verify_figure_quality(parsed: ParsedDocument, report: QualityReport, blocking_types: set[str]) -> None: + figures = parsed.figures + with_visual = sum(1 for figure in figures if figure.image_path or figure.bbox) + with_description = sum(1 for figure in figures if figure.caption or figure.vlm_description) + + report.metrics.update( + { + "figure_image_coverage": _coverage(len(figures), with_visual), + "figure_description_coverage": _coverage(len(figures), with_description), + } + ) + + for figure in figures: + if not figure.image_path and not figure.bbox: + _add( + report, + "missing_figure_region", + "warning", + f"Figure {figure.figure_id} has no image path or bounding box.", + figure.page_num, + blocking_types, + region_id=figure.figure_id, + ) + if not figure.caption and not figure.vlm_description: + _add( + report, + "missing_figure_caption", + "warning", + f"Figure {figure.figure_id} has no caption or VLM description.", + figure.page_num, + blocking_types, + region_id=figure.figure_id, + ) + + +def _coverage(total: int, observed: int) -> float: + if total == 0: + return 1.0 + return min(observed / total, 1.0) + + +def _add( + report: QualityReport, + issue_type: str, + severity: str, + message: str, + page_num: int, + blocking_types: set[str], + *, + region_id: str | None = None, +) -> None: + report.add_issue( + issue_type, + severity, + message, + page_num=page_num, + region_id=region_id, + blocking=issue_type in blocking_types, + ) diff --git a/zsgdp/verify/formula_extraction.py b/zsgdp/verify/formula_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..15c1bdde1d182fe662d5da0d97abb1f249f645f4 --- /dev/null +++ b/zsgdp/verify/formula_extraction.py @@ -0,0 +1,182 @@ +"""Formula extraction accuracy against ground-truth LaTeX strings. + +Definitions (pinned): + +- A FormulaRecord is a dict with keys: latex (str, required), page_num + (int|None), formula_id (str|None), bbox (xyxy|None). +- Predictions are matched to truths greedily, page-aware, by minimum + character error rate (CER) — the lowest-CER pair wins. +- character_error_rate (CER) = Levenshtein(predicted, truth) / max(1, len(truth)) + on whitespace-stripped strings. Capped at 1.0 so unbounded predictions can + not dominate the mean. +- accuracy = 1 - CER. exact_match = 1.0 when normalized strings are equal. +- Document-level aggregates: + - mean_cer over matched pairs (lower is better) + - mean_accuracy = 1 - mean_cer + - exact_match_rate over matched pairs + - formula_recall = matched_pairs / max(1, truth_count) + - formula_precision = matched_pairs / max(1, prediction_count) +- Empty/empty -> mean_cer=0.0, mean_accuracy=1.0 (vacuous). One side empty + -> mean_cer=1.0, mean_accuracy=0.0. +""" + +from __future__ import annotations + +import re +from typing import Any, Iterable + +FormulaRecord = dict[str, Any] + + +def compute_formula_extraction( + predictions: Iterable[FormulaRecord], + truths: Iterable[FormulaRecord], +) -> dict[str, Any]: + pred_list = [item for item in (_normalize(record) for record in predictions) if item is not None] + truth_list = [item for item in (_normalize(record) for record in truths) if item is not None] + + if not pred_list and not truth_list: + return _empty_result(vacuous=True) + if not pred_list or not truth_list: + result = _empty_result(vacuous=False) + result["prediction_count"] = len(pred_list) + result["truth_count"] = len(truth_list) + return result + + pairs: list[tuple[float, int, int, dict[str, float]]] = [] + for pred_index, prediction in enumerate(pred_list): + for truth_index, truth in enumerate(truth_list): + if prediction["page_num"] is not None and truth["page_num"] is not None: + if prediction["page_num"] != truth["page_num"]: + continue + scores = _score_pair(prediction["latex"], truth["latex"]) + pairs.append((scores["cer"], pred_index, truth_index, scores)) + + pairs.sort(key=lambda item: item[0]) + pred_taken = [False] * len(pred_list) + truth_taken = [False] * len(truth_list) + matches: list[dict[str, Any]] = [] + for _cer, pred_index, truth_index, scores in pairs: + if pred_taken[pred_index] or truth_taken[truth_index]: + continue + pred_taken[pred_index] = True + truth_taken[truth_index] = True + matches.append( + { + "prediction_index": pred_index, + "truth_index": truth_index, + "page_num": pred_list[pred_index]["page_num"] or truth_list[truth_index]["page_num"], + "cer": scores["cer"], + "accuracy": scores["accuracy"], + "exact_match": scores["exact_match"], + } + ) + + matched_pair_count = len(matches) + cer_values = [match["cer"] for match in matches] + mean_cer = sum(cer_values) / matched_pair_count if matched_pair_count else 1.0 + exact_match_rate = ( + sum(1 for match in matches if match["exact_match"] >= 0.999) / matched_pair_count + if matched_pair_count + else 0.0 + ) + return { + "prediction_count": len(pred_list), + "truth_count": len(truth_list), + "matched_pair_count": matched_pair_count, + "mean_cer": mean_cer, + "mean_accuracy": 1.0 - mean_cer, + "exact_match_rate": exact_match_rate, + "formula_precision": matched_pair_count / max(1, len(pred_list)), + "formula_recall": matched_pair_count / max(1, len(truth_list)), + "matches": matches, + "unmatched_predictions": [index for index, taken in enumerate(pred_taken) if not taken], + "unmatched_truths": [index for index, taken in enumerate(truth_taken) if not taken], + } + + +def _empty_result(vacuous: bool) -> dict[str, Any]: + if vacuous: + return { + "prediction_count": 0, + "truth_count": 0, + "matched_pair_count": 0, + "mean_cer": 0.0, + "mean_accuracy": 1.0, + "exact_match_rate": 1.0, + "formula_precision": 1.0, + "formula_recall": 1.0, + "matches": [], + "unmatched_predictions": [], + "unmatched_truths": [], + } + return { + "prediction_count": 0, + "truth_count": 0, + "matched_pair_count": 0, + "mean_cer": 1.0, + "mean_accuracy": 0.0, + "exact_match_rate": 0.0, + "formula_precision": 0.0, + "formula_recall": 0.0, + "matches": [], + "unmatched_predictions": [], + "unmatched_truths": [], + } + + +def _normalize(record: FormulaRecord | None) -> FormulaRecord | None: + if not isinstance(record, dict): + return None + latex = record.get("latex") or record.get("text") or record.get("markdown") + if not latex or not str(latex).strip(): + return None + page_num = record.get("page_num") + return { + "latex": _normalize_latex(str(latex)), + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + "formula_id": record.get("formula_id") or record.get("element_id"), + "bbox": record.get("bbox"), + } + + +def _normalize_latex(latex: str) -> str: + text = re.sub(r"\s+", " ", latex).strip() + if text.startswith("$$") and text.endswith("$$"): + text = text[2:-2].strip() + elif text.startswith("$") and text.endswith("$"): + text = text[1:-1].strip() + return text + + +def _score_pair(predicted: str, truth: str) -> dict[str, float]: + if not truth: + return {"cer": 1.0 if predicted else 0.0, "accuracy": 0.0, "exact_match": 0.0} + distance = _levenshtein(predicted, truth) + cer = min(1.0, distance / max(1, len(truth))) + return { + "cer": cer, + "accuracy": 1.0 - cer, + "exact_match": 1.0 if predicted == truth else 0.0, + } + + +def _levenshtein(a: str, b: str) -> int: + if a == b: + return 0 + if not a: + return len(b) + if not b: + return len(a) + previous = list(range(len(b) + 1)) + for i, char_a in enumerate(a, start=1): + current = [i] + [0] * len(b) + for j, char_b in enumerate(b, start=1): + cost = 0 if char_a == char_b else 1 + current[j] = min( + previous[j] + 1, + current[j - 1] + 1, + previous[j - 1] + cost, + ) + previous = current + return previous[-1] diff --git a/zsgdp/verify/formula_quality.py b/zsgdp/verify/formula_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..5035e76f22f2b2a30735c3865c947c5d093e6420 --- /dev/null +++ b/zsgdp/verify/formula_quality.py @@ -0,0 +1,57 @@ +"""Formula quality checks.""" + +from __future__ import annotations + +from zsgdp.schema import DocumentProfile, ParsedDocument, QualityReport + + +def verify_formula_quality( + profile: DocumentProfile, + parsed: ParsedDocument, + report: QualityReport, + config: dict, + blocking_types: set[str], +) -> None: + threshold = float(config.get("routing", {}).get("formula_density_threshold", 0.15)) + formula_pages = [page for page in profile.pages if page.formula_density >= threshold] + formula_elements = [element for element in parsed.elements if element.type == "formula"] + formula_pages_with_elements = {element.page_num for element in formula_elements} + + report.metrics.update( + { + "formula_candidate_count": len(formula_elements), + "formula_heavy_page_count": len(formula_pages), + "formula_page_coverage": _coverage( + len(formula_pages), + sum(1 for page in formula_pages if page.page_num in formula_pages_with_elements), + ), + } + ) + + for page in formula_pages: + if page.page_num not in formula_pages_with_elements: + report.add_issue( + "missing_formula_regions", + "warning", + f"Page {page.page_num} looks formula-heavy but no formula elements were emitted.", + page_num=page.page_num, + blocking="missing_formula_regions" in blocking_types, + metadata={"formula_density": page.formula_density, "threshold": threshold}, + ) + + for element in formula_elements: + if not element.content().strip(): + report.add_issue( + "empty_formula", + "warning", + f"Formula element {element.element_id} is empty.", + page_num=element.page_num, + element_id=element.element_id, + blocking="empty_formula" in blocking_types, + ) + + +def _coverage(total: int, observed: int) -> float: + if total == 0: + return 1.0 + return min(observed / total, 1.0) diff --git a/zsgdp/verify/layout_f1.py b/zsgdp/verify/layout_f1.py new file mode 100644 index 0000000000000000000000000000000000000000..cfab9f7c891b5a56dd53a95c5eb5ab90718a9199 --- /dev/null +++ b/zsgdp/verify/layout_f1.py @@ -0,0 +1,152 @@ +"""Layout F1 against ground-truth bbox annotations. + +Definitions (pinned — do not move without updating tests and docs): + +- A LayoutItem is a dict with keys: bbox (xyxy tuple of 4 floats), category + (string, lowercased), page_num (int or None). Both predictions and truths + must be in this shape; adapters live in zsgdp.benchmarks.ground_truth. +- Two items are matchable if and only if they share the same page_num AND + their categories match (case-insensitive, after the GT adapter has applied + the dataset's category-name normalization). The IoU between their bboxes + must be >= iou_threshold (default 0.5). +- Matching is greedy and bipartite: predictions are sorted by descending IoU + with the best available unmatched truth. Each truth and each prediction + may participate in at most one match. This is intentionally simpler than + Hungarian matching — it matches the COCO-style mAP@0.5 convention closely + enough for layout F1 reporting. +- precision = TP / (TP + FP) over matched pairs, where FP = unmatched + predictions; recall = TP / (TP + FN) where FN = unmatched truths. +- f1 = harmonic mean of precision and recall, with 0.0 when both sides are + empty (no signal) and 1.0 when one side is empty AND the other is too. +- Two parallel scores are emitted: `class_aware` (the definition above) and + `class_agnostic` (same matcher, ignoring category equality). The pair lets + callers see whether errors are localization or classification. +""" + +from __future__ import annotations + +from typing import Any, Iterable + +LayoutItem = dict[str, Any] + + +def compute_layout_f1( + predictions: Iterable[LayoutItem], + truths: Iterable[LayoutItem], + *, + iou_threshold: float = 0.5, +) -> dict[str, Any]: + pred_list = [_normalize(item) for item in predictions] + pred_list = [item for item in pred_list if item is not None] + truth_list = [_normalize(item) for item in truths] + truth_list = [item for item in truth_list if item is not None] + + class_aware = _score(pred_list, truth_list, iou_threshold=iou_threshold, category_aware=True) + class_agnostic = _score(pred_list, truth_list, iou_threshold=iou_threshold, category_aware=False) + per_category = _per_category(pred_list, truth_list, iou_threshold=iou_threshold) + return { + "iou_threshold": iou_threshold, + "prediction_count": len(pred_list), + "truth_count": len(truth_list), + "class_aware": class_aware, + "class_agnostic": class_agnostic, + "per_category": per_category, + } + + +def _normalize(item: LayoutItem | None) -> LayoutItem | None: + if not isinstance(item, dict): + return None + bbox = _coerce_bbox(item.get("bbox")) + if bbox is None: + return None + category = str(item.get("category", "")).strip().lower() + page_num = item.get("page_num") + return { + "bbox": bbox, + "category": category, + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + } + + +def _coerce_bbox(raw: Any) -> tuple[float, float, float, float] | None: + if not raw or not isinstance(raw, (list, tuple)) or len(raw) < 4: + return None + try: + x0, y0, x1, y1 = (float(raw[0]), float(raw[1]), float(raw[2]), float(raw[3])) + except (TypeError, ValueError): + return None + if x1 <= x0 or y1 <= y0: + return None + return (x0, y0, x1, y1) + + +def _score( + predictions: list[LayoutItem], + truths: list[LayoutItem], + *, + iou_threshold: float, + category_aware: bool, +) -> dict[str, float]: + if not predictions and not truths: + return {"precision": 1.0, "recall": 1.0, "f1": 1.0, "tp": 0, "fp": 0, "fn": 0} + if not predictions: + return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "tp": 0, "fp": 0, "fn": len(truths)} + if not truths: + return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "tp": 0, "fp": len(predictions), "fn": 0} + + truth_taken = [False] * len(truths) + pairs: list[tuple[float, int, int]] = [] + for prediction_index, prediction in enumerate(predictions): + for truth_index, truth in enumerate(truths): + if prediction["page_num"] != truth["page_num"]: + continue + if category_aware and prediction["category"] != truth["category"]: + continue + iou = _iou(prediction["bbox"], truth["bbox"]) + if iou >= iou_threshold: + pairs.append((iou, prediction_index, truth_index)) + + pairs.sort(key=lambda item: item[0], reverse=True) + pred_taken = [False] * len(predictions) + tp = 0 + for _iou_value, pred_index, truth_index in pairs: + if pred_taken[pred_index] or truth_taken[truth_index]: + continue + pred_taken[pred_index] = True + truth_taken[truth_index] = True + tp += 1 + + fp = sum(1 for taken in pred_taken if not taken) + fn = sum(1 for taken in truth_taken if not taken) + precision = tp / (tp + fp) if (tp + fp) else 0.0 + recall = tp / (tp + fn) if (tp + fn) else 0.0 + f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0 + return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn} + + +def _per_category( + predictions: list[LayoutItem], + truths: list[LayoutItem], + *, + iou_threshold: float, +) -> dict[str, dict[str, float]]: + categories = sorted({item["category"] for item in predictions} | {item["category"] for item in truths}) + out: dict[str, dict[str, float]] = {} + for category in categories: + pred_subset = [item for item in predictions if item["category"] == category] + truth_subset = [item for item in truths if item["category"] == category] + out[category] = _score(pred_subset, truth_subset, iou_threshold=iou_threshold, category_aware=False) + return out + + +def _iou(a: tuple[float, float, float, float], b: tuple[float, float, float, float]) -> float: + ix0, iy0 = max(a[0], b[0]), max(a[1], b[1]) + ix1, iy1 = min(a[2], b[2]), min(a[3], b[3]) + if ix1 <= ix0 or iy1 <= iy0: + return 0.0 + intersection = (ix1 - ix0) * (iy1 - iy0) + area_a = max(0.0, (a[2] - a[0]) * (a[3] - a[1])) + area_b = max(0.0, (b[2] - b[0]) * (b[3] - b[1])) + union = area_a + area_b - intersection + return intersection / union if union > 0 else 0.0 diff --git a/zsgdp/verify/parser_disagreement.py b/zsgdp/verify/parser_disagreement.py new file mode 100644 index 0000000000000000000000000000000000000000..d12ab85c1b7a0db5adea62ba64c48c268c9b41bb --- /dev/null +++ b/zsgdp/verify/parser_disagreement.py @@ -0,0 +1,68 @@ +"""Parser disagreement-rate metric. + +Definitions (pinned so the metric does not drift as the verifier evolves): + +- parser_pair_count = C(successful_candidate_count, 2). One "comparison surface" + per unordered parser pair that produced a candidate parse. +- disagreement_rate = conflict_count / parser_pair_count when parser_pair_count + > 0; 0.0 when the document only had one (or zero) successful parsers. +- conflict_count is taken from merge.conflict_report (text-coverage gaps, + table/figure count disagreements, reading-order disagreements, table-structure + disagreements). We do not redefine "conflict" here — we count what the merger + already detected, so the metric stays consistent with the repair triggers. +- disagreement_by_type: counts grouped by conflict.type. +- disagreement_by_parser_pair: counts grouped by sorted (parser_a, parser_b). +""" + +from __future__ import annotations + +from itertools import combinations +from typing import Any + + +def compute_parser_disagreement( + conflict_report: dict[str, Any] | None, + parser_metrics: dict[str, dict[str, Any]] | None, +) -> dict[str, Any]: + successful_parsers = _successful_parsers(parser_metrics) + parser_pair_count = len(list(combinations(successful_parsers, 2))) + conflicts = list((conflict_report or {}).get("conflicts") or []) + + by_type: dict[str, int] = {} + by_pair: dict[str, int] = {} + for conflict in conflicts: + conflict_type = str(conflict.get("type", "unknown")) + by_type[conflict_type] = by_type.get(conflict_type, 0) + 1 + pair_key = _pair_key(conflict.get("parsers")) + if pair_key is not None: + by_pair[pair_key] = by_pair.get(pair_key, 0) + 1 + + rate = (len(conflicts) / parser_pair_count) if parser_pair_count else 0.0 + return { + "candidate_count": len(successful_parsers), + "parser_pair_count": parser_pair_count, + "conflict_count": len(conflicts), + "disagreement_rate": rate, + "disagreement_by_type": dict(sorted(by_type.items())), + "disagreement_by_parser_pair": dict(sorted(by_pair.items())), + "successful_parsers": successful_parsers, + } + + +def _successful_parsers(parser_metrics: dict[str, dict[str, Any]] | None) -> list[str]: + if not parser_metrics: + return [] + return sorted( + parser_name + for parser_name, metrics in parser_metrics.items() + if isinstance(metrics, dict) and not metrics.get("failed") + ) + + +def _pair_key(parsers: Any) -> str | None: + if not isinstance(parsers, list) or len(parsers) < 2: + return None + names = sorted(str(name) for name in parsers[:2] if name) + if len(names) < 2: + return None + return f"{names[0]}|{names[1]}" diff --git a/zsgdp/verify/parser_metrics.py b/zsgdp/verify/parser_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..5d165b08460c60fa3aba806a562eb9e2ad12fd1c --- /dev/null +++ b/zsgdp/verify/parser_metrics.py @@ -0,0 +1,54 @@ +"""Per-parser candidate metrics.""" + +from __future__ import annotations + +from zsgdp.schema import DocumentProfile, ParseCandidate +from zsgdp.verify.table_quality import markdown_table_is_valid + + +def candidate_metrics(candidate: ParseCandidate, profile: DocumentProfile, *, elapsed_seconds: float | None = None) -> dict: + expected_chars = _expected_chars(candidate, profile) + text_chars = sum(len(element.content()) for element in candidate.elements) + table_markdown_count = sum(1 for table in candidate.tables if table.markdown) + valid_table_count = sum(1 for table in candidate.tables if markdown_table_is_valid(table.markdown)) + pages = sorted({page.get("page_num", 1) for page in candidate.pages} or {element.page_num for element in candidate.elements}) + metrics = { + "parser": candidate.parser_name, + "confidence": candidate.confidence, + "elapsed_seconds": elapsed_seconds, + "page_count": len(pages), + "pages": pages, + "element_count": len(candidate.elements), + "table_count": len(candidate.tables), + "figure_count": len(candidate.figures), + "text_chars": text_chars, + "expected_text_chars": expected_chars, + "text_coverage_ratio": min(text_chars / expected_chars, 1.0) if expected_chars else (1.0 if text_chars else 0.0), + "table_markdown_count": table_markdown_count, + "valid_table_count": valid_table_count, + "valid_table_ratio": valid_table_count / table_markdown_count if table_markdown_count else 1.0, + "has_bboxes": any(element.bbox for element in candidate.elements), + "has_page_images": any("rendered_page" in page for page in candidate.pages), + "source_path": candidate.source_path, + } + return metrics + + +def failure_metrics(parser_name: str, profile: DocumentProfile, message: str, *, elapsed_seconds: float | None = None) -> dict: + return { + "parser": parser_name, + "failed": True, + "error": message, + "elapsed_seconds": elapsed_seconds, + "page_count": profile.page_count, + "source_path": profile.source_path, + } + + +def _expected_chars(candidate: ParseCandidate, profile: DocumentProfile) -> int: + candidate_pages = {int(page.get("page_num", 1)) for page in candidate.pages} + if not candidate_pages: + candidate_pages = {element.page_num for element in candidate.elements} + if not candidate_pages: + return sum(page.digital_text_chars for page in profile.pages) + return sum(page.digital_text_chars for page in profile.pages if page.page_num in candidate_pages) diff --git a/zsgdp/verify/quality_report.py b/zsgdp/verify/quality_report.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2eb4b5b8490f70de846d99f79b01cf4d637530 --- /dev/null +++ b/zsgdp/verify/quality_report.py @@ -0,0 +1,112 @@ +"""Build parse quality reports.""" + +from __future__ import annotations + +from zsgdp.schema import DocumentProfile, ParsedDocument, QualityReport +from zsgdp.verify.coverage import coverage_ratio, page_text_chars +from zsgdp.verify.figure_quality import verify_figure_quality +from zsgdp.verify.formula_quality import verify_formula_quality +from zsgdp.verify.reading_order import reading_order_is_monotonic +from zsgdp.verify.table_quality import markdown_table_is_valid + + +def verify_parse(profile: DocumentProfile, parsed: ParsedDocument, config: dict | None = None) -> QualityReport: + config = config or {} + quality_config = config.get("quality", {}) + blocking_types = set(quality_config.get("blocking_failures", [])) + report = QualityReport(score=1.0, issues=list(parsed.quality_report.issues), metrics={}) + + total_expected = 0 + total_observed = 0 + for page in profile.pages: + expected = page.digital_text_chars + observed = page_text_chars(parsed, page.page_num) + total_expected += expected + total_observed += observed + + if expected > 0 and observed == 0: + _add(report, "empty_page", "error", f"Page {page.page_num} has no parsed text.", page.page_num, blocking_types) + elif expected >= 200: + ratio = coverage_ratio(page, parsed) + if ratio < 0.55: + _add( + report, + "missing_text_coverage", + "warning", + f"Page {page.page_num} text coverage is low: {ratio:.2f}.", + page.page_num, + blocking_types, + metadata={"coverage_ratio": ratio}, + ) + + if not reading_order_is_monotonic(parsed, page.page_num): + _add( + report, + "reading_order_failure", + "warning", + f"Page {page.page_num} reading-order values are not monotonic.", + page.page_num, + blocking_types, + ) + + for table in parsed.tables: + if table.markdown and not markdown_table_is_valid(table.markdown): + _add( + report, + "invalid_table", + "warning", + f"Table {table.table_id} is not a valid Markdown table.", + min(table.page_nums or [1]), + blocking_types, + region_id=table.table_id, + ) + + verify_figure_quality(parsed, report, blocking_types) + verify_formula_quality(profile, parsed, report, config, blocking_types) + + doc_coverage = (total_observed / total_expected) if total_expected else (1.0 if parsed.elements else 0.0) + report.metrics.update( + { + "expected_text_chars": total_expected, + "observed_text_chars": total_observed, + "document_text_coverage": min(doc_coverage, 1.0), + "element_count": len(parsed.elements), + "table_count": len(parsed.tables), + "figure_count": len(parsed.figures), + } + ) + penalty = _score_penalty(report) + report.score = max(0.0, min(1.0, report.metrics["document_text_coverage"] - penalty)) + return report + + +def _add( + report: QualityReport, + issue_type: str, + severity: str, + message: str, + page_num: int | None, + blocking_types: set[str], + *, + region_id: str | None = None, + metadata: dict | None = None, +) -> None: + report.add_issue( + issue_type, + severity, + message, + page_num=page_num, + region_id=region_id, + blocking=issue_type in blocking_types, + metadata=metadata, + ) + + +def _score_penalty(report: QualityReport) -> float: + penalty = 0.0 + for issue in report.issues: + if issue.severity == "error": + penalty += 0.15 + elif issue.severity == "warning": + penalty += 0.04 + return min(penalty, 0.6) diff --git a/zsgdp/verify/reading_order.py b/zsgdp/verify/reading_order.py new file mode 100644 index 0000000000000000000000000000000000000000..268fc3ee6da032399cc07899a698ca1274fea80e --- /dev/null +++ b/zsgdp/verify/reading_order.py @@ -0,0 +1,14 @@ +"""Reading-order checks.""" + +from __future__ import annotations + +from zsgdp.schema import ParsedDocument + + +def reading_order_is_monotonic(parsed: ParsedDocument, page_num: int) -> bool: + orders = [ + element.reading_order + for element in parsed.elements + if element.page_num == page_num and element.reading_order is not None + ] + return orders == sorted(orders) diff --git a/zsgdp/verify/repair_success.py b/zsgdp/verify/repair_success.py new file mode 100644 index 0000000000000000000000000000000000000000..596a023c8dc2ed96c35f06f140a050ce99a99344 --- /dev/null +++ b/zsgdp/verify/repair_success.py @@ -0,0 +1,103 @@ +"""Repair-loop success metrics. + +Definitions (pinned — do not change without a corresponding test update): + +- An "issue identity" is the tuple (issue_type, page_num, region_id). Two + issues are considered the same if and only if their identities are equal. + This is a deliberate, slightly fuzzy notion: a re-detected issue with the + same type/page/region after repair is treated as unresolved, even if the + message string changed. Newly detected issues with the same type but a + different region count as new (regression candidates), not unresolved. + +- pre_repair_blocking = identities of issues in pre_repair_quality.issues + where blocking is True. +- post_repair_blocking = identities of issues in the *current* quality_report + where blocking is True. +- resolved = pre_repair_blocking - post_repair_blocking. +- regressed = post_repair_blocking - pre_repair_blocking + (i.e. blocking issues introduced by repair, not present pre-repair). + +- repair_resolution_rate = len(resolved) / len(pre_repair_blocking) when + pre_repair_blocking is non-empty; 1.0 (vacuous success) when there were no + blocking failures pre-repair. +- repair_regression_rate = len(regressed) / max(1, len(pre_repair_blocking)). + We use the pre-repair denominator deliberately so the rate is interpretable + as "regressions per pre-existing blocking issue." + +The same shape (resolved/regressed/rate pair) is also computed over *all* +issues regardless of blocking flag, exposed under the *_any keys, so callers +can audit how repair affected non-blocking warnings as well. +""" + +from __future__ import annotations + +from typing import Any + + +def compute_repair_success( + pre_repair_quality: dict[str, Any] | None, + post_repair_quality: dict[str, Any] | None, + repair_iterations: list[dict[str, Any]] | None, +) -> dict[str, Any]: + pre_issues = _issues(pre_repair_quality) + post_issues = _issues(post_repair_quality) + + pre_blocking = {_identity(issue) for issue in pre_issues if issue.get("blocking")} + post_blocking = {_identity(issue) for issue in post_issues if issue.get("blocking")} + resolved_blocking = pre_blocking - post_blocking + regressed_blocking = post_blocking - pre_blocking + + pre_any = {_identity(issue) for issue in pre_issues} + post_any = {_identity(issue) for issue in post_issues} + resolved_any = pre_any - post_any + regressed_any = post_any - pre_any + + history = list(repair_iterations or []) + score_delta = _score_delta(history, post_repair_quality) + iteration_count = len(history) + actions_per_iteration = [len(item.get("actions") or []) for item in history] + total_actions = sum(actions_per_iteration) + + return { + "iteration_count": iteration_count, + "total_actions": total_actions, + "actions_per_iteration": actions_per_iteration, + "score_delta": score_delta, + "pre_repair_blocking_count": len(pre_blocking), + "post_repair_blocking_count": len(post_blocking), + "resolved_blocking_count": len(resolved_blocking), + "regressed_blocking_count": len(regressed_blocking), + "repair_resolution_rate": (len(resolved_blocking) / len(pre_blocking)) if pre_blocking else 1.0, + "repair_regression_rate": (len(regressed_blocking) / len(pre_blocking)) if pre_blocking else 0.0, + "pre_repair_issue_count": len(pre_any), + "post_repair_issue_count": len(post_any), + "resolved_any_count": len(resolved_any), + "regressed_any_count": len(regressed_any), + "repair_resolution_rate_any": (len(resolved_any) / len(pre_any)) if pre_any else 1.0, + "repair_regression_rate_any": (len(regressed_any) / len(pre_any)) if pre_any else 0.0, + } + + +def _issues(quality: dict[str, Any] | None) -> list[dict[str, Any]]: + if not isinstance(quality, dict): + return [] + issues = quality.get("issues") + return [issue for issue in issues if isinstance(issue, dict)] if isinstance(issues, list) else [] + + +def _identity(issue: dict[str, Any]) -> tuple[str, Any, Any]: + return ( + str(issue.get("issue_type", "")), + issue.get("page_num"), + issue.get("region_id"), + ) + + +def _score_delta(history: list[dict[str, Any]], post_repair_quality: dict[str, Any] | None) -> float: + final_score = 0.0 + if isinstance(post_repair_quality, dict): + final_score = float(post_repair_quality.get("score", 0.0)) + if not history: + return 0.0 + first_before = float(history[0].get("before_score", 0.0)) + return final_score - first_before diff --git a/zsgdp/verify/retrieval.py b/zsgdp/verify/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..46e0bf2aea005d66690ceb47487c2670d61adc1c --- /dev/null +++ b/zsgdp/verify/retrieval.py @@ -0,0 +1,87 @@ +"""Retrieval-readiness metrics for RAG-style benchmarking. + +Definitions (pinned): + +- A QueryResult is (retrieved: list[str], truths: list[str]). Both lists hold + chunk_ids. `retrieved` is in rank order, best first. `truths` is the set of + acceptable chunk_ids; a query is satisfied if any truth appears in the + retrieved prefix. +- recall_at_k = (# queries with at least one truth in retrieved[:k]) / total + queries. Computed for each k in k_values (default 1, 3, 5). +- citation_accuracy_at_k = same as recall_at_k under this contract — the + query is "cited correctly" if any source chunk_id appears in the retrieved + top-k. We expose it under a second name so the benchmark JSON makes the + spec mapping explicit (citation accuracy in §18.2). +- mean_reciprocal_rank: MRR = mean(1/rank_of_first_truth) when at least one + truth appears in retrieved; 0.0 when no truth appears at any rank. +- empty queries (no truths) are skipped; if every query is empty, all metrics + return 1.0 (vacuous). +""" + +from __future__ import annotations + +from typing import Any, Iterable, Sequence + +QueryResult = tuple[Sequence[str], Sequence[str]] + + +def compute_retrieval_metrics( + queries: Iterable[QueryResult], + *, + k_values: Sequence[int] = (1, 3, 5), +) -> dict[str, Any]: + queries_list = [(_dedupe(retrieved), set(truths)) for retrieved, truths in queries] + queries_list = [(retrieved, truths) for retrieved, truths in queries_list if truths] + + if not queries_list: + return _empty_result(k_values) + + recall_counts = {k: 0 for k in k_values} + citation_counts = {k: 0 for k in k_values} + reciprocal_ranks: list[float] = [] + + for retrieved, truths in queries_list: + first_hit_rank: int | None = None + for rank, chunk_id in enumerate(retrieved, start=1): + if chunk_id in truths: + first_hit_rank = rank + break + for k in k_values: + top_k = retrieved[:k] + hit = any(chunk_id in truths for chunk_id in top_k) + if hit: + recall_counts[k] += 1 + citation_counts[k] += 1 + reciprocal_ranks.append(1.0 / first_hit_rank if first_hit_rank else 0.0) + + total = len(queries_list) + return { + "query_count": total, + "k_values": list(k_values), + "recall_at_k": {k: recall_counts[k] / total for k in k_values}, + "citation_accuracy_at_k": {k: citation_counts[k] / total for k in k_values}, + "mean_reciprocal_rank": sum(reciprocal_ranks) / total, + "hits_at_k_counts": dict(recall_counts), + } + + +def _dedupe(retrieved: Sequence[str]) -> list[str]: + seen: set[str] = set() + out: list[str] = [] + for chunk_id in retrieved: + if chunk_id in seen: + continue + seen.add(chunk_id) + out.append(chunk_id) + return out + + +def _empty_result(k_values: Sequence[int]) -> dict[str, Any]: + return { + "query_count": 0, + "k_values": list(k_values), + "recall_at_k": {k: 1.0 for k in k_values}, + "citation_accuracy_at_k": {k: 1.0 for k in k_values}, + "mean_reciprocal_rank": 1.0, + "hits_at_k_counts": {k: 0 for k in k_values}, + } diff --git a/zsgdp/verify/table_quality.py b/zsgdp/verify/table_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..19d7f3156022fa7e9f1f27f780b1dfdcc6a65ee3 --- /dev/null +++ b/zsgdp/verify/table_quality.py @@ -0,0 +1,20 @@ +"""Table quality checks.""" + +from __future__ import annotations + + +def markdown_table_is_valid(markdown: str | None) -> bool: + if not markdown: + return False + rows = [row.strip() for row in markdown.splitlines() if row.strip()] + pipe_rows = [row for row in rows if row.startswith("|") and row.endswith("|")] + if len(pipe_rows) < 2: + return False + widths = [len([cell for cell in row.strip("|").split("|")]) for row in pipe_rows] + content_rows = [row for row in pipe_rows if not _is_separator_row(row)] + return len(set(widths)) == 1 and widths[0] >= 2 and len(content_rows) >= 2 + + +def _is_separator_row(row: str) -> bool: + cells = [cell.strip() for cell in row.strip("|").split("|")] + return bool(cells) and all(cell.strip(":") == "---" for cell in cells if cell) diff --git a/zsgdp/verify/table_structure.py b/zsgdp/verify/table_structure.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3e963c14a15b6c31e69afd70927338fcc32861 --- /dev/null +++ b/zsgdp/verify/table_structure.py @@ -0,0 +1,218 @@ +"""Table structure similarity against ground-truth tables. + +Definitions (pinned): + +- A TableRecord is a dict with keys: rows (list[list[str]] of normalized cell + strings), page_num (int|None), table_id (str|None), bbox (xyxy|None). +- Predictions are matched to truths greedily, page-aware. The match score + blends shape similarity (row/col count) with cell-content overlap, so a + prediction with the same dimensions and overlapping headers wins over one + with extra phantom rows. +- For each matched pair we compute: + - shape_similarity = 1 - mean(|rows_p - rows_t|/max_rows, |cols_p - cols_t|/max_cols) + bounded to [0, 1]. + - cell_content_f1 = F1 over the multiset of normalized cell strings. + Multiset (not set) so duplicate "0.00" cells in financial tables count. + - score = 0.5 * shape_similarity + 0.5 * cell_content_f1. +- Document-level aggregates: + - mean_table_score = mean(score) over matched pairs. + - table_match_rate = matched_pairs / max(len(predictions), len(truths)). + - table_count_delta = len(predictions) - len(truths). +- Edge cases: empty/empty -> mean_table_score=1.0 (vacuous); one side empty + -> 0.0 with all unmatched recorded as fp/fn. +""" + +from __future__ import annotations + +import re +from typing import Any, Iterable + +TableRecord = dict[str, Any] + + +def compute_table_structure_score( + predictions: Iterable[TableRecord], + truths: Iterable[TableRecord], +) -> dict[str, Any]: + pred_list = [item for item in (_normalize(record) for record in predictions) if item is not None] + truth_list = [item for item in (_normalize(record) for record in truths) if item is not None] + + if not pred_list and not truth_list: + return { + "prediction_count": 0, + "truth_count": 0, + "matched_pair_count": 0, + "table_match_rate": 1.0, + "mean_table_score": 1.0, + "mean_shape_similarity": 1.0, + "mean_cell_content_f1": 1.0, + "table_count_delta": 0, + "matches": [], + "unmatched_predictions": [], + "unmatched_truths": [], + } + + pairs: list[tuple[float, int, int, dict[str, float]]] = [] + for pred_index, prediction in enumerate(pred_list): + for truth_index, truth in enumerate(truth_list): + if prediction["page_num"] is not None and truth["page_num"] is not None: + if prediction["page_num"] != truth["page_num"]: + continue + scores = _score_pair(prediction, truth) + if scores["match_score"] <= 0.0: + continue + pairs.append((scores["match_score"], pred_index, truth_index, scores)) + + pairs.sort(key=lambda item: item[0], reverse=True) + pred_taken = [False] * len(pred_list) + truth_taken = [False] * len(truth_list) + matches: list[dict[str, Any]] = [] + for _match_score, pred_index, truth_index, scores in pairs: + if pred_taken[pred_index] or truth_taken[truth_index]: + continue + pred_taken[pred_index] = True + truth_taken[truth_index] = True + matches.append( + { + "prediction_index": pred_index, + "truth_index": truth_index, + "page_num": pred_list[pred_index]["page_num"] or truth_list[truth_index]["page_num"], + "score": scores["score"], + "shape_similarity": scores["shape_similarity"], + "cell_content_f1": scores["cell_content_f1"], + "predicted_shape": [len(pred_list[pred_index]["rows"]), _max_cols(pred_list[pred_index]["rows"])], + "truth_shape": [len(truth_list[truth_index]["rows"]), _max_cols(truth_list[truth_index]["rows"])], + } + ) + + matched_pair_count = len(matches) + denominator = max(len(pred_list), len(truth_list)) + mean_table_score = (sum(match["score"] for match in matches) / denominator) if denominator else 1.0 + mean_shape = (sum(match["shape_similarity"] for match in matches) / matched_pair_count) if matched_pair_count else 0.0 + mean_cell = (sum(match["cell_content_f1"] for match in matches) / matched_pair_count) if matched_pair_count else 0.0 + + return { + "prediction_count": len(pred_list), + "truth_count": len(truth_list), + "matched_pair_count": matched_pair_count, + "table_match_rate": (matched_pair_count / denominator) if denominator else 1.0, + "mean_table_score": mean_table_score, + "mean_shape_similarity": mean_shape, + "mean_cell_content_f1": mean_cell, + "table_count_delta": len(pred_list) - len(truth_list), + "matches": matches, + "unmatched_predictions": [index for index, taken in enumerate(pred_taken) if not taken], + "unmatched_truths": [index for index, taken in enumerate(truth_taken) if not taken], + } + + +def markdown_to_rows(markdown: str | None) -> list[list[str]]: + if not markdown: + return [] + rows: list[list[str]] = [] + for line in markdown.splitlines(): + stripped = line.strip() + if not (stripped.startswith("|") and stripped.endswith("|")): + continue + cells = [_normalize_cell(cell) for cell in stripped.strip("|").split("|")] + if cells and all(re.fullmatch(r":?-{3,}:?", cell) for cell in cells if cell): + continue + rows.append(cells) + return rows + + +def html_to_rows(html: str | None) -> list[list[str]]: + if not html: + return [] + rows: list[list[str]] = [] + for tr_match in re.finditer(r"]*>(.*?)", html, flags=re.IGNORECASE | re.DOTALL): + tr_body = tr_match.group(1) + cells = [ + _normalize_cell(_strip_html(cell.group(1))) + for cell in re.finditer(r"]*>(.*?)", tr_body, flags=re.IGNORECASE | re.DOTALL) + ] + if cells: + rows.append(cells) + return rows + + +def _strip_html(text: str) -> str: + return re.sub(r"<[^>]+>", " ", text) + + +def _normalize_cell(cell: str) -> str: + return " ".join(cell.replace("\xa0", " ").split()).strip().lower() + + +def _normalize(record: TableRecord | None) -> TableRecord | None: + if not isinstance(record, dict): + return None + rows = record.get("rows") + if rows is None: + rows = markdown_to_rows(record.get("markdown")) + if not rows: + rows = html_to_rows(record.get("html")) + rows = [[_normalize_cell(cell) for cell in row] for row in rows if row] + if not rows: + return None + page_num = record.get("page_num") + return { + "rows": rows, + "page_num": int(page_num) if isinstance(page_num, (int, float)) else None, + "table_id": record.get("table_id"), + "bbox": record.get("bbox"), + } + + +def _score_pair(prediction: TableRecord, truth: TableRecord) -> dict[str, float]: + shape_similarity = _shape_similarity(prediction["rows"], truth["rows"]) + cell_content_f1 = _multiset_f1(_cell_multiset(prediction["rows"]), _cell_multiset(truth["rows"])) + score = 0.5 * shape_similarity + 0.5 * cell_content_f1 + # Prefer matches that share at least one cell content token to break ties. + overlap_bonus = 0.0 + if cell_content_f1 > 0: + overlap_bonus = 0.01 + return { + "shape_similarity": shape_similarity, + "cell_content_f1": cell_content_f1, + "score": score, + "match_score": score + overlap_bonus, + } + + +def _shape_similarity(a: list[list[str]], b: list[list[str]]) -> float: + rows_a = len(a) + rows_b = len(b) + cols_a = _max_cols(a) + cols_b = _max_cols(b) + max_rows = max(rows_a, rows_b, 1) + max_cols = max(cols_a, cols_b, 1) + row_diff = abs(rows_a - rows_b) / max_rows + col_diff = abs(cols_a - cols_b) / max_cols + return max(0.0, 1.0 - 0.5 * (row_diff + col_diff)) + + +def _max_cols(rows: list[list[str]]) -> int: + return max((len(row) for row in rows), default=0) + + +def _cell_multiset(rows: list[list[str]]) -> list[str]: + return [cell for row in rows for cell in row if cell] + + +def _multiset_f1(predicted: list[str], truth: list[str]) -> float: + if not predicted and not truth: + return 1.0 + if not predicted or not truth: + return 0.0 + truth_remaining = list(truth) + tp = 0 + for cell in predicted: + if cell in truth_remaining: + truth_remaining.remove(cell) + tp += 1 + fp = len(predicted) - tp + fn = len(truth_remaining) + precision = tp / (tp + fp) if (tp + fp) else 0.0 + recall = tp / (tp + fn) if (tp + fn) else 0.0 + return (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0