yaml-bert / yaml_bert /cache.py
vimalk78's picture
Initial app: Gradio missing-field suggester (v6.1 model)
222a479 verified
Raw
History Blame Contribute Delete
3.36 kB
"""Cache linearized YAML documents to disk.

Linearizes all documents once (with multiprocessing) and saves them to a
pickle file. Both VocabBuilder and YamlDataset load from the cache, so
nothing is re-parsed.

Usage:
    from yaml_bert.cache import build_or_load_cache

    docs = build_or_load_cache(
        dataset_name="substratusai/the-stack-yaml-k8s",
        cache_path="output_v3/doc_cache.pkl",
        max_docs=276520,
    )
"""
from __future__ import annotations

import logging
import os
import pickle
import time
from multiprocessing import Pool, cpu_count
from typing import Any

from yaml_bert.annotator import DomainAnnotator
from yaml_bert.linearizer import YamlLinearizer
from yaml_bert.types import YamlNode
def _linearize_one(yaml_content: str) -> list[YamlNode] | None:
    """Linearize a single YAML string. Used by multiprocessing pool.

    Builds a fresh ``YamlLinearizer`` and ``DomainAnnotator``, linearizes
    ``yaml_content`` and, when that yields a non-empty node list, runs
    ``DomainAnnotator.annotate`` on it and returns that same list.

    Args:
        yaml_content: Text of one YAML document.

    Returns:
        The node list, or ``None`` when linearization produced nothing or
        any step raised an exception.
    """
    try:
        linearizer = YamlLinearizer()
        annotator = DomainAnnotator()
        nodes: list[YamlNode] = linearizer.linearize(yaml_content)
        if nodes:
            annotator.annotate(nodes)
            return nodes
    except Exception:
        # Deliberately best-effort: one unparsable document must not abort a
        # whole pool run, so the caller just sees ``None``.  The failure is
        # still recorded at DEBUG level so skipped documents can be diagnosed
        # instead of vanishing silently.
        logging.getLogger(__name__).debug(
            "Skipping document that failed to linearize", exc_info=True
        )
    return None
def build_or_load_cache(
    dataset_name: str,
    cache_path: str,
    max_docs: int | None = None,
    num_workers: int | None = None,
) -> list[list[YamlNode]]:
    """Build or load cached linearized documents.

    If ``cache_path`` already exists it is unpickled and returned as-is;
    ``dataset_name``, ``max_docs`` and ``num_workers`` are NOT consulted in
    that case, so delete the file to rebuild with different settings.

    Otherwise the ``train`` split of ``dataset_name`` is loaded, the
    ``content`` field of the first ``max_docs`` rows is linearized in a
    process pool via ``_linearize_one``, documents for which that returns
    ``None`` are dropped, and the remaining list is pickled to
    ``cache_path``.

    Args:
        dataset_name: HuggingFace dataset ID
        cache_path: Path to save/load the cache pickle
        max_docs: Max documents to process (None = all)
        num_workers: Number of parallel workers (None or 0 = cpu_count)

    Returns:
        List of linearized documents (each is a list of YamlNode)
    """
    if os.path.exists(cache_path):
        print(f"Loading cached documents from {cache_path}")
        start: float = time.time()
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # cache_path at files produced by this function / a trusted source.
        with open(cache_path, "rb") as f:
            documents: list[list[YamlNode]] = pickle.load(f)
        print(f"Loaded {len(documents):,} documents in {time.time() - start:.1f}s")
        return documents

    # Heavy third-party imports are deferred so the cache-hit path above
    # works without them.
    from datasets import load_dataset
    from tqdm import tqdm

    workers: int = num_workers or cpu_count()

    print(f"Building document cache from {dataset_name}")
    ds = load_dataset(dataset_name, split="train")
    # Clamp at 0 so a negative max_docs selects nothing.
    total: int = len(ds) if max_docs is None else max(0, min(max_docs, len(ds)))
    print(f"Linearizing {total:,} documents with {workers} workers...")

    # Extract YAML content strings. Reading the column from a contiguous
    # selection avoids decoding every column of every row one at a time.
    yaml_contents: list[str] = list(ds.select(range(total))["content"])

    # Parallel linearization with progress bar; split successes from
    # failures in a single pass instead of keeping a second full list.
    start = time.time()
    documents = []
    skipped: int = 0
    with Pool(workers) as pool:
        linearized = pool.imap(_linearize_one, yaml_contents, chunksize=500)
        for doc in tqdm(linearized, total=len(yaml_contents), desc="Linearizing"):
            if doc is None:
                skipped += 1
            else:
                documents.append(doc)
    elapsed: float = time.time() - start
    print(f"Linearized {len(documents):,} documents in {elapsed:.1f}s ({skipped} skipped)")

    # Save cache atomically: write to a sibling temp file and rename it into
    # place, so an interrupted run never leaves a truncated pickle that the
    # next call would find via os.path.exists() and fail to load.
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    tmp_path: str = f"{cache_path}.tmp.{os.getpid()}"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(documents, f)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial temp file, then
        # let the original error propagate unchanged.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone
        raise

    cache_size: float = os.path.getsize(cache_path) / 1024 / 1024
    print(f"Cache saved: {cache_path} ({cache_size:.1f} MB)")
    return documents