File size: 3,757 Bytes
704baa5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | from __future__ import annotations
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Any
from .constants import (
CATALOG_JSONL,
CATALOG_PARQUET,
FINALCASCADE_JSONL,
FINALCASCADE_SUMMARY_PARQUET,
LOCAL_DATASET_ENV,
PUBLIC_DATASET_REPO,
STAGES_JSONL,
STAGES_PARQUET,
)
class SimpleTable:
    """Minimal list-of-dicts stand-in for a pandas DataFrame.

    The table loaders return this when pandas cannot be imported. It supports
    only ``len()`` and ``to_dict(orient="records")``.
    """

    def __init__(self, rows: list[dict[str, Any]]):
        # Stored by reference; to_dict() hands out a fresh list each time.
        self._rows = rows

    def __len__(self) -> int:
        return len(self._rows)

    def to_dict(self, orient: str = "records") -> list[dict[str, Any]]:
        """Return the rows as a new list.

        Raises:
            ValueError: if *orient* is anything other than ``"records"``.
        """
        if orient == "records":
            return [*self._rows]
        raise ValueError("SimpleTable only supports orient='records'")
def _as_local_root(local_dataset_dir: Path | str | None = None) -> Path | None:
    """Return the absolute local dataset root, or ``None`` if none is configured.

    A truthy *local_dataset_dir* takes precedence. Otherwise the value of the
    ``LOCAL_DATASET_ENV`` environment variable is used. ``~`` is expanded and
    the result is resolved to an absolute path.
    """
    candidate = local_dataset_dir if local_dataset_dir else os.environ.get(LOCAL_DATASET_ENV)
    if candidate:
        return Path(candidate).expanduser().resolve()
    return None
def resolve_dataset_file(filename: str, local_dataset_dir: Path | str | None = None) -> Path:
    """Return a filesystem path for *filename*, preferring a local dataset root.

    If a local root is configured (explicitly or through ``LOCAL_DATASET_ENV``),
    the file is resolved beneath that root. Otherwise it is fetched with
    ``hf_hub_download`` from ``PUBLIC_DATASET_REPO`` (``repo_type="dataset"``)
    and the downloaded path is returned.

    Raises:
        ValueError: if *filename* resolves to a location outside the local root.
        FileNotFoundError: if the resolved local path does not exist.
    """
    local_root = _as_local_root(local_dataset_dir)
    if local_root is None:
        # Imported lazily so that local-only use does not need huggingface_hub.
        from huggingface_hub import hf_hub_download

        downloaded = hf_hub_download(
            repo_id=PUBLIC_DATASET_REPO,
            repo_type="dataset",
            filename=filename,
        )
        return Path(downloaded)

    candidate = (local_root / filename).resolve()
    # resolve() collapses ".." (and an absolute *filename* replaces the root),
    # so this comparison rejects any path that escapes the dataset directory.
    if not candidate.is_relative_to(local_root):
        raise ValueError(f"Refusing to read outside local dataset root: {filename}")
    if not candidate.exists():
        raise FileNotFoundError(candidate)
    return candidate
def load_jsonl_rows(filename: str, local_dataset_dir: Path | str | None = None) -> list[dict[str, Any]]:
    """Decode every non-blank line of a JSON Lines dataset file.

    The file is located with ``resolve_dataset_file`` and read as UTF-8. Each
    line is stripped, whitespace-only lines are skipped, and the remaining
    lines are passed to ``json.loads`` in file order.
    """
    source = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir)
    with source.open("r", encoding="utf-8") as stream:
        return [json.loads(text) for text in map(str.strip, stream) if text]
def read_event_graph_from_jsonl(
    event_id: str, local_dataset_dir: Path | str | None = None
) -> dict[str, Any]:
    """Return the first ``FINALCASCADE_JSONL`` row whose ``event_id`` matches.

    The file is scanned lazily, line by line, and blank lines are ignored.
    Scanning stops at the first match.

    Raises:
        KeyError: if no row has the requested *event_id*.
    """
    source = resolve_dataset_file(FINALCASCADE_JSONL, local_dataset_dir=local_dataset_dir)
    with source.open("r", encoding="utf-8") as stream:
        records = (json.loads(raw) for raw in stream if raw.strip())
        found = next((rec for rec in records if rec.get("event_id") == event_id), None)
    if found is None:
        raise KeyError(f"Event graph not found: {event_id}")
    return found
def _read_table(filename: str, fallback_jsonl: str, local_dataset_dir: Path | str | None = None):
    """Load a dataset table from Parquet, falling back to its JSON Lines copy.

    When pandas is importable, *filename* is resolved with
    ``resolve_dataset_file`` and read with ``pandas.read_parquet``. If that
    raises ``FileNotFoundError``, ``ImportError`` (for example, a missing
    Parquet engine or a missing huggingface_hub) or ``ValueError`` (for
    example, the out-of-root check in ``resolve_dataset_file``), the rows of
    *fallback_jsonl* are loaded into a ``pandas.DataFrame`` instead.

    When pandas is not importable, the Parquet file is never resolved. This
    avoids downloading a file that could not be read anyway. The JSONL rows
    are wrapped in a ``SimpleTable``.

    Returns:
        A ``pandas.DataFrame`` if pandas is available, otherwise a ``SimpleTable``.
    """
    try:
        import pandas as pd
    except ImportError:
        pd = None

    if pd is None:
        # No pandas means no Parquet reader: go straight to the JSONL fallback
        # instead of resolving (and possibly downloading) the Parquet file.
        return SimpleTable(load_jsonl_rows(fallback_jsonl, local_dataset_dir=local_dataset_dir))

    try:
        path = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir)
        return pd.read_parquet(path)
    except (FileNotFoundError, ImportError, ValueError):
        # The Parquet copy is unusable; fall through to the JSONL copy below.
        pass
    return pd.DataFrame(load_jsonl_rows(fallback_jsonl, local_dataset_dir=local_dataset_dir))
@lru_cache(maxsize=1)
def load_catalog():
    """Return the dataset catalog table, memoized after the first call.

    Reads ``CATALOG_PARQUET`` and falls back to ``CATALOG_JSONL`` (see
    ``_read_table``). Every call returns the same cached object, so callers
    should copy it before mutating.
    """
    return _read_table(filename=CATALOG_PARQUET, fallback_jsonl=CATALOG_JSONL)
@lru_cache(maxsize=1)
def load_stages():
    """Return the stages table, memoized after the first call.

    Reads ``STAGES_PARQUET`` and falls back to ``STAGES_JSONL`` (see
    ``_read_table``). Every call returns the same cached object, so callers
    should copy it before mutating.
    """
    return _read_table(filename=STAGES_PARQUET, fallback_jsonl=STAGES_JSONL)
@lru_cache(maxsize=1)
def load_finalcascade_summary():
    """Return the final-cascade summary table, memoized after the first call.

    Reads ``FINALCASCADE_SUMMARY_PARQUET`` (see ``_read_table``). Every call
    returns the same cached object, so callers should copy it before mutating.
    """
    # NOTE(review): the JSONL fallback here is the *catalog* file, not a
    # summary-specific JSONL. Confirm the catalog rows are an acceptable
    # substitute for the summary table.
    return _read_table(
        filename=FINALCASCADE_SUMMARY_PARQUET,
        fallback_jsonl=CATALOG_JSONL,
    )
@lru_cache(maxsize=128)
def load_event_graph(event_id: str) -> dict[str, Any]:
    """Return the final-cascade graph row for *event_id*, memoized per id.

    Up to 128 distinct ids are kept in an LRU cache. The cached dict is shared
    by every caller asking for the same id, so copy it before mutating.

    Raises:
        KeyError: if the event is not present. Raised calls are not cached.
    """
    return read_event_graph_from_jsonl(event_id=event_id)
|