Spaces:
Sleeping
Sleeping
File size: 2,969 Bytes
"""Convenience loader: outputs/ground_truth/*.json -> ReferenceTranslation rows.

This is a thin wrapper around ``db_ingestion.ingest_reference_translations``
that exposes a clean single-purpose API for callers (T4 judges, demo
scripts, eval harness).

CLI:
    python -m polyglot_alpha.corpus.reference_loader \\
        --path outputs/ground_truth/
"""
from __future__ import annotations
import argparse
import asyncio
import logging
from pathlib import Path
from typing import Optional
from polyglot_alpha.corpus.db_ingestion import (
DEFAULT_REFERENCES_DIR,
IngestStats,
ingest_reference_translations,
)
from polyglot_alpha.persistence import db as persistence_db
from polyglot_alpha.persistence import init_db
from polyglot_alpha.persistence.models import ReferenceTranslation
from sqlmodel import Session
def _read_session() -> Session:
    """Open a new session on the shared engine for fetch-only helpers.

    The session is created with ``expire_on_commit=False`` so attributes
    that were already loaded on a row are not expired by a commit.  Nothing
    here enforces read-only access; that is a convention of the callers.

    Returns:
        A fresh ``Session`` bound to ``persistence_db.engine``.  The caller
        owns it and is expected to close it (e.g. via a ``with`` block).
    """
    # Resolve the engine at call time rather than import time.
    engine = persistence_db.engine
    return Session(engine, expire_on_commit=False)
# Module-level logger named after this module; main() uses it to report
# the ingest counts.
LOGGER = logging.getLogger(__name__)
async def load_references(
    path: Path = DEFAULT_REFERENCES_DIR,
    *,
    ensure_schema: bool = True,
) -> IngestStats:
    """Ingest ground-truth JSON files into ``reference_translations``.

    Args:
        path: Location handed unchanged to
            ``ingest_reference_translations``.
        ensure_schema: When true, ``init_db()`` is called before the
            ingest starts.

    Returns:
        The ``IngestStats`` produced by ``ingest_reference_translations``.
    """
    if ensure_schema:
        # Synchronous call; it completes before the ingest coroutine runs.
        init_db()
    stats = await ingest_reference_translations(path)
    return stats
def get_reference(sample_id: int) -> Optional[ReferenceTranslation]:
    """Fetch one reference translation synchronously.

    Args:
        sample_id: Identity value passed to ``Session.get`` for the
            ``ReferenceTranslation`` model.

    Returns:
        The matching row, expunged from the session that loaded it, or
        ``None`` when ``Session.get`` finds nothing.
    """
    with _read_session() as db:
        found = db.get(ReferenceTranslation, sample_id)
        if found is None:
            return None
        # Remove the instance from the session before handing it out.
        db.expunge(found)
        return found
def list_references(limit: int = 100) -> list[ReferenceTranslation]:
    """Fetch reference translations synchronously, ordered by ``sample_id``.

    Args:
        limit: Value applied to the query's ``LIMIT`` clause, so at most
            this many rows are returned (default 100).

    Returns:
        The selected rows in ascending ``sample_id`` order, each expunged
        from the session that loaded it.
    """
    from sqlalchemy import select

    # Building the statement needs no session, so do it up front.
    query = select(ReferenceTranslation)
    query = query.order_by(ReferenceTranslation.sample_id.asc())
    query = query.limit(limit)

    with _read_session() as db:
        fetched = list(db.execute(query).scalars())
        # Expunge while the session is still open.
        for item in fetched:
            db.expunge(item)
        return fetched
def _build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for this module.

    Returns:
        An ``ArgumentParser`` with a single optional ``--path`` argument,
        converted to ``Path`` and defaulting to ``DEFAULT_REFERENCES_DIR``.
    """
    cli = argparse.ArgumentParser(
        description="Load human-verified reference translations into the DB."
    )
    path_option = {
        "type": Path,
        "default": DEFAULT_REFERENCES_DIR,
        "help": "Directory of JSON files or a single JSONL file.",
    }
    cli.add_argument("--path", **path_option)
    return cli
def main(argv: Optional[list[str]] = None) -> int:
    """Run the loader from the command line.

    Configures root logging at INFO level, parses ``argv``, runs
    ``load_references`` to completion on a fresh event loop and logs the
    inserted/updated/skipped counts.

    Args:
        argv: Argument list for the parser; ``None`` lets ``argparse``
            read ``sys.argv``.

    Returns:
        Always ``0``.
    """
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        level=logging.INFO,
    )
    options = _build_parser().parse_args(argv)
    outcome = asyncio.run(load_references(options.path))
    counts = (outcome.inserted, outcome.updated, outcome.skipped)
    LOGGER.info(
        "References loaded: %d inserted, %d updated, %d skipped",
        *counts,
    )
    return 0
if __name__ == "__main__":
    # Script entry point: the value returned by main() becomes the
    # process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)
|