| """Cache linearized YAML documents to disk. |
| |
| Linearizes all documents once (with multiprocessing), saves to pickle. |
| Both VocabBuilder and YamlDataset load from cache — no re-parsing. |
| |
| Usage: |
| from yaml_bert.cache import build_or_load_cache |
| |
| docs = build_or_load_cache( |
| dataset_name="substratusai/the-stack-yaml-k8s", |
| cache_path="output_v3/doc_cache.pkl", |
| max_docs=276520, |
| ) |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import pickle |
| import time |
| from multiprocessing import Pool, cpu_count |
| from typing import Any |
|
|
| from yaml_bert.annotator import DomainAnnotator |
| from yaml_bert.linearizer import YamlLinearizer |
| from yaml_bert.types import YamlNode |
|
|
|
|
def _linearize_one(yaml_content: str) -> list[YamlNode] | None:
    """Turn one raw YAML string into an annotated node list.

    This is the per-document task handed to the multiprocessing pool, so it
    is a module-level function and builds its own linearizer and annotator
    on every call.

    Args:
        yaml_content: Text of a single YAML document.

    Returns:
        The nodes produced by ``YamlLinearizer.linearize`` after
        ``DomainAnnotator.annotate`` has been called on them, or ``None``
        when the linearizer produced nothing or any exception was raised.
    """
    try:
        doc_linearizer = YamlLinearizer()
        doc_annotator = DomainAnnotator()
        parsed: list[YamlNode] = doc_linearizer.linearize(yaml_content)
        if not parsed:
            # Nothing to annotate; the caller counts this document as skipped.
            return None
        # The return value is ignored; presumably annotate() updates the
        # nodes in place — confirm against DomainAnnotator.
        doc_annotator.annotate(parsed)
    except Exception:
        # Deliberate best-effort: one bad document must not abort the batch.
        return None
    return parsed
|
|
|
|
def build_or_load_cache(
    dataset_name: str,
    cache_path: str,
    max_docs: int | None = None,
    num_workers: int | None = None,
) -> list[list[YamlNode]]:
    """Build or load cached linearized documents.

    If ``cache_path`` already exists, it is unpickled and returned as-is;
    ``dataset_name``, ``max_docs`` and ``num_workers`` are not consulted in
    that case. Otherwise the ``train`` split of the dataset is loaded, the
    ``content`` field of the first ``max_docs`` rows is linearized in a
    process pool, documents for which ``_linearize_one`` returned ``None``
    are dropped, and the result is pickled to ``cache_path``.

    Args:
        dataset_name: HuggingFace dataset ID
        cache_path: Path to save/load the cache pickle
        max_docs: Max documents to process (None = all)
        num_workers: Number of parallel workers (None or 0 = cpu_count)

    Returns:
        List of linearized documents (each is a list of YamlNode)
    """
    if os.path.exists(cache_path):
        print(f"Loading cached documents from {cache_path}")
        start: float = time.time()
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # cache_path at files produced by this function.
        with open(cache_path, "rb") as f:
            documents: list[list[YamlNode]] = pickle.load(f)
        print(f"Loaded {len(documents):,} documents in {time.time() - start:.1f}s")
        return documents

    # Imported lazily so that loading an existing cache does not require
    # these packages to be importable.
    from datasets import load_dataset
    from tqdm import tqdm

    print(f"Building document cache from {dataset_name}")
    ds = load_dataset(dataset_name, split="train")

    # Resolve the worker count once so the log line and the pool agree.
    workers: int = num_workers or cpu_count()
    # Clamp at zero so a negative max_docs means "no documents" rather than
    # producing a negative count in the log line.
    total: int = len(ds) if max_docs is None else max(0, min(max_docs, len(ds)))
    print(f"Linearizing {total:,} documents with {workers} workers...")

    yaml_contents: list[str] = [ds[i]["content"] for i in range(total)]

    start = time.time()
    with Pool(workers) as pool:
        # imap yields results in input order; chunksize batches documents
        # per task sent to a worker.
        results: list[list[YamlNode] | None] = list(tqdm(
            pool.imap(_linearize_one, yaml_contents, chunksize=500),
            total=len(yaml_contents),
            desc="Linearizing",
        ))

    documents = [doc for doc in results if doc is not None]
    skipped: int = len(results) - len(documents)
    elapsed: float = time.time() - start

    print(f"Linearized {len(documents):,} documents in {elapsed:.1f}s ({skipped} skipped)")

    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    # Write to a sibling temp file and rename it into place. An interrupted
    # dump then never leaves a truncated pickle at cache_path, which the
    # next run would find via os.path.exists() and fail to load.
    tmp_path: str = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(documents, f)
        os.replace(tmp_path, cache_path)
    finally:
        # After a successful replace the temp file is gone; this only
        # cleans up after a failed or interrupted write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    cache_size: float = os.path.getsize(cache_path) / 1024 / 1024
    print(f"Cache saved: {cache_path} ({cache_size:.1f} MB)")

    return documents
|
|