File size: 3,757 Bytes
704baa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Any

from .constants import (
    CATALOG_JSONL,
    CATALOG_PARQUET,
    FINALCASCADE_JSONL,
    FINALCASCADE_SUMMARY_PARQUET,
    LOCAL_DATASET_ENV,
    PUBLIC_DATASET_REPO,
    STAGES_JSONL,
    STAGES_PARQUET,
)


class SimpleTable:
    """Minimal row container returned in place of a DataFrame when pandas is missing.

    Only the small surface used by this module's callers is provided:
    ``len(table)`` and ``table.to_dict(orient="records")``.
    """

    def __init__(self, rows: list[dict[str, Any]]):
        # The list is kept by reference; it is not copied here.
        self._rows = rows

    def __len__(self) -> int:
        """Return the number of rows held by the table."""
        row_count = len(self._rows)
        return row_count

    def to_dict(self, orient: str = "records") -> list[dict[str, Any]]:
        """Return the rows as a new list of row dictionaries.

        Args:
            orient: Must be ``"records"``; no other layout is supported.

        Returns:
            A shallow copy of the row list (the row dicts themselves are shared).

        Raises:
            ValueError: If ``orient`` is anything other than ``"records"``.
        """
        if orient == "records":
            return [*self._rows]
        raise ValueError("SimpleTable only supports orient='records'")


def _as_local_root(local_dataset_dir: Path | str | None = None) -> Path | None:
    value = local_dataset_dir or os.environ.get(LOCAL_DATASET_ENV)
    if not value:
        return None
    return Path(value).expanduser().resolve()


def resolve_dataset_file(filename: str, local_dataset_dir: Path | str | None = None) -> Path:
    """Return a filesystem path for ``filename`` from the dataset.

    When a local dataset root is configured (argument or environment, see
    ``_as_local_root``), the file is looked up under that root only. Otherwise
    it is fetched from ``PUBLIC_DATASET_REPO`` through ``hf_hub_download``.

    Args:
        filename: Path of the file relative to the dataset root / repository.
        local_dataset_dir: Optional local dataset root overriding the environment.

    Returns:
        Path to the local file, or to the file returned by the hub download.

    Raises:
        ValueError: If ``filename`` resolves to a location outside the local root.
        FileNotFoundError: If the file does not exist under the local root.
    """
    root = _as_local_root(local_dataset_dir)

    if root is None:
        # Imported lazily so purely local use does not need huggingface_hub.
        from huggingface_hub import hf_hub_download

        downloaded = hf_hub_download(
            repo_id=PUBLIC_DATASET_REPO,
            repo_type="dataset",
            filename=filename,
        )
        return Path(downloaded)

    # Resolve before checking containment so ".." segments and symlinks
    # cannot be used to escape the root.
    candidate = (root / filename).resolve()
    if not candidate.is_relative_to(root):
        raise ValueError(f"Refusing to read outside local dataset root: {filename}")
    if not candidate.exists():
        raise FileNotFoundError(candidate)
    return candidate


def load_jsonl_rows(filename: str, local_dataset_dir: Path | str | None = None) -> list[dict[str, Any]]:
    """Read a JSON Lines dataset file into a list of parsed rows.

    Args:
        filename: Dataset-relative name of the JSONL file.
        local_dataset_dir: Optional local dataset root, forwarded to
            ``resolve_dataset_file``.

    Returns:
        One parsed JSON value per non-blank line, in file order.
    """
    jsonl_path = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir)
    with jsonl_path.open("r", encoding="utf-8") as stream:
        # Blank / whitespace-only lines are skipped rather than parsed.
        stripped_lines = (raw.strip() for raw in stream)
        return [json.loads(text) for text in stripped_lines if text]


def read_event_graph_from_jsonl(
    event_id: str, local_dataset_dir: Path | str | None = None
) -> dict[str, Any]:
    """Scan ``FINALCASCADE_JSONL`` and return the first row matching ``event_id``.

    The file is streamed line by line and scanning stops at the first match,
    so lines after the match are never parsed.

    Args:
        event_id: Value compared against each row's ``"event_id"`` key.
        local_dataset_dir: Optional local dataset root, forwarded to
            ``resolve_dataset_file``.

    Returns:
        The parsed row whose ``"event_id"`` equals ``event_id``.

    Raises:
        KeyError: If no row in the file has that ``event_id``.
    """
    jsonl_path = resolve_dataset_file(FINALCASCADE_JSONL, local_dataset_dir=local_dataset_dir)
    with jsonl_path.open("r", encoding="utf-8") as stream:
        # Lazily parse non-blank lines; nothing past the match is decoded.
        records = (json.loads(raw) for raw in stream if raw.strip())
        for record in records:
            if record.get("event_id") == event_id:
                return record
    raise KeyError(f"Event graph not found: {event_id}")


def _read_table(filename: str, fallback_jsonl: str, local_dataset_dir: Path | str | None = None):
    """Load a dataset table, preferring Parquet and falling back to JSONL.

    Args:
        filename: Parquet file to read when pandas can be imported.
        fallback_jsonl: JSONL file to read when the Parquet route is unusable.
        local_dataset_dir: Optional local dataset root, forwarded to
            ``resolve_dataset_file`` and ``load_jsonl_rows``.

    Returns:
        A ``pandas.DataFrame`` when pandas is importable (from Parquet, or
        built from the JSONL rows on fallback); otherwise a ``SimpleTable``
        wrapping the JSONL rows.

    Raises:
        Whatever ``load_jsonl_rows`` raises if the JSONL fallback itself fails.
    """
    try:
        import pandas as pd
    except ImportError:
        pd = None

    if pd is None:
        # Without pandas the Parquet file could never be read, so do not even
        # resolve it (resolution may download it); go straight to JSONL.
        return SimpleTable(load_jsonl_rows(fallback_jsonl, local_dataset_dir=local_dataset_dir))

    try:
        path = resolve_dataset_file(filename, local_dataset_dir=local_dataset_dir)
        return pd.read_parquet(path)
    except (FileNotFoundError, ImportError, ValueError):
        # Deliberate best-effort: a missing/rejected Parquet path or a failure
        # inside read_parquet of one of these types falls back to JSONL.
        return pd.DataFrame(load_jsonl_rows(fallback_jsonl, local_dataset_dir=local_dataset_dir))


@lru_cache(maxsize=1)
def load_catalog():
    """Return the catalog table (Parquet first, ``CATALOG_JSONL`` as fallback).

    The result is memoized, so every call returns the same table object.
    """
    return _read_table(filename=CATALOG_PARQUET, fallback_jsonl=CATALOG_JSONL)


@lru_cache(maxsize=1)
def load_stages():
    """Return the stages table (Parquet first, ``STAGES_JSONL`` as fallback).

    The result is memoized, so every call returns the same table object.
    """
    return _read_table(filename=STAGES_PARQUET, fallback_jsonl=STAGES_JSONL)


@lru_cache(maxsize=1)
def load_finalcascade_summary():
    """Return the final-cascade summary table read from its Parquet file.

    The result is memoized, so every call returns the same table object.
    """
    # NOTE(review): the fallback here is CATALOG_JSONL, not a summary-specific
    # JSONL file — confirm the catalog rows are an intended substitute.
    return _read_table(filename=FINALCASCADE_SUMMARY_PARQUET, fallback_jsonl=CATALOG_JSONL)


@lru_cache(maxsize=128)
def load_event_graph(event_id: str) -> dict[str, Any]:
    """Return the event graph row for ``event_id``, memoized per id.

    Up to 128 distinct ids are cached; a cached id returns the same dict
    object on each call. A missing id raises ``KeyError`` from
    ``read_event_graph_from_jsonl``.
    """
    return read_event_graph_from_jsonl(event_id=event_id)