from __future__ import annotations import json import os from functools import lru_cache from pathlib import Path from typing import Any from .constants import ( CATALOG_JSONL, CATALOG_PARQUET, FINALCASCADE_JSONL, FINALCASCADE_SUMMARY_PARQUET, LOCAL_DATASET_ENV, PUBLIC_DATASET_REPO, STAGES_JSONL, STAGES_PARQUET, ) class SimpleTable: """Small fallback table used when pandas is unavailable in local checks.""" def __init__(self, rows: list[dict[str, Any]]): self._rows = rows def __len__(self) -> int: return len(self._rows) def to_dict(self, orient: str = "records") -> list[dict[str, Any]]: if orient != "records": raise ValueError("SimpleTable only supports orient='records'") return list(self._rows) def _as_local_root(local_dataset_dir: Path | str | None = None) -> Path | None: value = local_dataset_dir or os.environ.get(LOCAL_DATASET_ENV) if not value: return None return Path(value).expanduser().resolve() def resolve_dataset_file(filename: str, local_dataset_dir: Path | str | None = None) -> Path: local_root = _as_local_root(local_dataset_dir) if local_root is not None: path = (local_root / filename).resolve() if not path.is_relative_to(local_root): raise ValueError(f"Refusing to read outside local dataset root: {filename}") if not path.exists(): raise FileNotFoundError(path) return path from huggingface_hub import hf_hub_download return Path( hf_hub_download( repo_id=PUBLIC_DATASET_REPO, repo_type="dataset", filename=filename, ) ) def load_jsonl_rows(filename: str, local_dataset_dir: Path | str | None = None) -> list[dict[str, Any]]: path = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir) rows: list[dict[str, Any]] = [] with path.open("r", encoding="utf-8") as handle: for line in handle: line = line.strip() if line: rows.append(json.loads(line)) return rows def read_event_graph_from_jsonl( event_id: str, local_dataset_dir: Path | str | None = None ) -> dict[str, Any]: path = resolve_dataset_file(FINALCASCADE_JSONL, local_dataset_dir=local_dataset_dir) with path.open("r", encoding="utf-8") as handle: for line in handle: if not line.strip(): continue row = json.loads(line) if row.get("event_id") == event_id: return row raise KeyError(f"Event graph not found: {event_id}") def _read_table(filename: str, fallback_jsonl: str, local_dataset_dir: Path | str | None = None): pd = None try: import pandas as pandas_module pd = pandas_module except ImportError: pass try: path = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir) if pd is None: raise ImportError("pandas is unavailable") return pd.read_parquet(path) except (FileNotFoundError, ImportError, ValueError): rows = load_jsonl_rows(fallback_jsonl, local_dataset_dir=local_dataset_dir) if pd is None: return SimpleTable(rows) return pd.DataFrame(rows) @lru_cache(maxsize=1) def load_catalog(): return _read_table(CATALOG_PARQUET, CATALOG_JSONL) @lru_cache(maxsize=1) def load_stages(): return _read_table(STAGES_PARQUET, STAGES_JSONL) @lru_cache(maxsize=1) def load_finalcascade_summary(): return _read_table(FINALCASCADE_SUMMARY_PARQUET, CATALOG_JSONL) @lru_cache(maxsize=128) def load_event_graph(event_id: str) -> dict[str, Any]: return read_event_graph_from_jsonl(event_id)