File size: 866 Bytes
a3682cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import pandas as pd

from src.graph.graph_builder import build_edge_index, build_edge_features, build_labels
from src.graph.node_features import build_node_features
from src.graph.temporal_split import temporal_split


def build_graph_dataset(df: pd.DataFrame, users: pd.DataFrame):
    """Bundle every graph component derived from ``df`` into one dict.

    Each piece is delegated to a project helper: ``build_edge_index``,
    ``build_edge_features`` and ``build_labels`` receive ``df``;
    ``build_node_features`` receives ``df`` and ``users``; and
    ``temporal_split`` supplies the train/val/test masks.

    Args:
        df: Frame passed to every helper. It must contain a
            ``"timestamp"`` column, which is read directly here.
        users: Frame forwarded unchanged to ``build_node_features``.

    Returns:
        A dict with the keys ``"edge_index"``, ``"edge_attr"``,
        ``"timestamps"``, ``"x"``, ``"y"``, ``"train_mask"``,
        ``"val_mask"`` and ``"test_mask"``.
    """
    connectivity = build_edge_index(df)
    edge_features = build_edge_features(df)
    labels = build_labels(df)

    node_features = build_node_features(df, users)

    # Raw timestamps for TGN time encoding, taken from a copy of ``df``
    # ordered by ascending "timestamp".
    # NOTE(review): the edge/label helpers are opaque from here -- confirm
    # they order rows the same way so these values line up with the edges.
    chronological = df.sort_values("timestamp").reset_index(drop=True)
    event_times = chronological["timestamp"].values

    # temporal_split yields four values; the fourth is intentionally unused.
    train_sel, val_sel, test_sel, _ = temporal_split(df)

    return dict(
        edge_index=connectivity,
        edge_attr=edge_features,
        timestamps=event_times,
        x=node_features,
        y=labels,
        train_mask=train_sel,
        val_mask=val_sel,
        test_mask=test_sel,
    )