wandler67's picture
Deploy H2EPR-Bench Explorer Space
704baa5 verified
Raw
History Blame Contribute Delete
3.76 kB
from __future__ import annotations
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Any
from .constants import (
CATALOG_JSONL,
CATALOG_PARQUET,
FINALCASCADE_JSONL,
FINALCASCADE_SUMMARY_PARQUET,
LOCAL_DATASET_ENV,
PUBLIC_DATASET_REPO,
STAGES_JSONL,
STAGES_PARQUET,
)
class SimpleTable:
    """Minimal DataFrame stand-in used when pandas cannot be imported.

    Wraps a sequence of row dictionaries and supports only ``len()`` and
    ``to_dict(orient="records")``.
    """

    def __init__(self, rows: list[dict[str, Any]]):
        self._rows = rows

    def __len__(self) -> int:
        return len(self._rows)

    def to_dict(self, orient: str = "records") -> list[dict[str, Any]]:
        """Return a new list holding the stored rows (the dicts are not copied).

        Raises:
            ValueError: if ``orient`` is anything other than ``"records"``.
        """
        if orient == "records":
            return [*self._rows]
        raise ValueError("SimpleTable only supports orient='records'")
def _as_local_root(local_dataset_dir: Path | str | None = None) -> Path | None:
    """Return the configured local dataset root, or ``None`` if there is none.

    A truthy ``local_dataset_dir`` argument takes precedence. Otherwise the
    environment variable named by ``LOCAL_DATASET_ENV`` is read. The chosen
    value is ``~``-expanded and resolved to an absolute path.
    """
    if local_dataset_dir:
        chosen = local_dataset_dir
    else:
        chosen = os.environ.get(LOCAL_DATASET_ENV)
    return Path(chosen).expanduser().resolve() if chosen else None
def resolve_dataset_file(filename: str, local_dataset_dir: Path | str | None = None) -> Path:
    """Return a filesystem path for ``filename`` from the dataset.

    If a local dataset root is configured (argument or environment), the
    file is looked up beneath it. Otherwise the file is fetched with
    ``hf_hub_download`` from the ``PUBLIC_DATASET_REPO`` dataset repo, and
    the returned path is wrapped in ``Path``.

    Raises:
        ValueError: if ``filename`` resolves outside the local root.
        FileNotFoundError: if the file is absent under the local root.
    """
    root = _as_local_root(local_dataset_dir)
    if root is None:
        # Imported here so purely local usage never executes this import.
        from huggingface_hub import hf_hub_download

        downloaded = hf_hub_download(
            repo_id=PUBLIC_DATASET_REPO,
            repo_type="dataset",
            filename=filename,
        )
        return Path(downloaded)

    candidate = (root / filename).resolve()
    # resolve() collapses ".." and follows symlinks; an absolute filename also
    # replaces the root when joined. In both cases this check rejects the path.
    if not candidate.is_relative_to(root):
        raise ValueError(f"Refusing to read outside local dataset root: {filename}")
    if not candidate.exists():
        raise FileNotFoundError(candidate)
    return candidate
def load_jsonl_rows(filename: str, local_dataset_dir: Path | str | None = None) -> list[dict[str, Any]]:
    """Decode every non-blank line of a JSON Lines dataset file.

    Each line is whitespace-stripped. Empty results are skipped, and the rest
    are parsed with ``json.loads`` and returned in file order.
    """
    source = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir)
    with source.open("r", encoding="utf-8") as stream:
        stripped = (raw.strip() for raw in stream)
        return [json.loads(text) for text in stripped if text]
def read_event_graph_from_jsonl(
    event_id: str, local_dataset_dir: Path | str | None = None
) -> dict[str, Any]:
    """Scan the final-cascade JSONL file for the row whose ``event_id`` matches.

    The file is read line by line and the first matching row is returned.
    Blank lines are ignored.

    Raises:
        KeyError: if no row carries the requested ``event_id``.
    """
    source = resolve_dataset_file(FINALCASCADE_JSONL, local_dataset_dir=local_dataset_dir)
    with source.open("r", encoding="utf-8") as stream:
        records = (json.loads(raw) for raw in stream if raw.strip())
        found = next((rec for rec in records if rec.get("event_id") == event_id), None)
    if found is None:
        raise KeyError(f"Event graph not found: {event_id}")
    return found
def _read_table(filename: str, fallback_jsonl: str, local_dataset_dir: Path | str | None = None):
    """Load a dataset table, preferring parquet and falling back to JSONL.

    Args:
        filename: Parquet file name inside the dataset.
        fallback_jsonl: JSON Lines file to read if the parquet path fails.
        local_dataset_dir: Optional local dataset root. It is forwarded to the
            resolver, which also consults the environment.

    Returns:
        A ``pandas.DataFrame`` when pandas is importable, otherwise a
        ``SimpleTable`` built from the JSONL rows.
    """
    try:
        import pandas as pd
    except ImportError:
        # Without pandas there is no parquet reader, so skip resolving (and
        # possibly downloading) the parquet file and go straight to JSONL.
        return SimpleTable(load_jsonl_rows(fallback_jsonl, local_dataset_dir=local_dataset_dir))

    try:
        path = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir)
        return pd.read_parquet(path)
    except (OSError, ImportError, ValueError):
        # OSError covers a missing local file (FileNotFoundError) and, per the
        # huggingface_hub error hierarchy, a file missing from the Hub repo.
        # NOTE(review): verify against the installed huggingface_hub version.
        # ImportError: no parquet engine installed.
        # ValueError: path outside the local root, or an unreadable parquet file.
        rows = load_jsonl_rows(fallback_jsonl, local_dataset_dir=local_dataset_dir)
        return pd.DataFrame(rows)
@lru_cache(maxsize=1)
def load_catalog():
    """Return the catalog table (parquet, or JSONL fallback), memoized after the first call."""
    return _read_table(filename=CATALOG_PARQUET, fallback_jsonl=CATALOG_JSONL)
@lru_cache(maxsize=1)
def load_stages():
    """Return the stages table (parquet, or JSONL fallback), memoized after the first call."""
    return _read_table(filename=STAGES_PARQUET, fallback_jsonl=STAGES_JSONL)
@lru_cache(maxsize=1)
def load_finalcascade_summary():
    """Return the final-cascade summary table, memoized after the first call.

    NOTE(review): the JSONL fallback is the *catalog* file, not a
    summary-specific one. Confirm that this substitution is intentional.
    """
    return _read_table(filename=FINALCASCADE_SUMMARY_PARQUET, fallback_jsonl=CATALOG_JSONL)
@lru_cache(maxsize=128)
def load_event_graph(event_id: str) -> dict[str, Any]:
    """Return the event-graph row for ``event_id``.

    Up to 128 results are cached, so repeated lookups avoid rescanning the file.
    """
    return read_event_graph_from_jsonl(event_id=event_id)