| import pandas as pd | |
| from src.graph.graph_builder import build_edge_index, build_edge_features, build_labels | |
| from src.graph.node_features import build_node_features | |
| from src.graph.temporal_split import temporal_split | |
def build_graph_dataset(df: pd.DataFrame, users: pd.DataFrame):
    """Bundle every piece of the graph dataset into a single dictionary.

    Each component is delegated to a dedicated builder; this function only
    orchestrates the calls and collects their results.

    Args:
        df: Interaction table. It must contain a ``"timestamp"`` column and
            is passed unchanged to the edge, label, node-feature and split
            builders.
        users: User table, forwarded to ``build_node_features`` together
            with ``df``.

    Returns:
        A dict with the keys ``"edge_index"``, ``"edge_attr"``,
        ``"timestamps"``, ``"x"``, ``"y"``, ``"train_mask"``, ``"val_mask"``
        and ``"test_mask"``.
    """
    connectivity = build_edge_index(df)
    edge_features = build_edge_features(df)
    labels = build_labels(df)
    node_features = build_node_features(df, users)

    # Raw timestamps for TGN time encoding, in ascending order.
    # NOTE(review): the sort happens here independently of the edge/label
    # builders -- confirm they order rows by timestamp as well, otherwise
    # this array will not line up with edge_index / edge_attr / y.
    chronological = df.sort_values("timestamp").reset_index(drop=True)
    event_times = chronological["timestamp"].values

    # temporal_split returns four values; the last one is not needed here.
    train_split, val_split, test_split, _unused = temporal_split(df)

    dataset = {
        "edge_index": connectivity,
        "edge_attr": edge_features,
        "timestamps": event_times,
        "x": node_features,
        "y": labels,
        "train_mask": train_split,
        "val_mask": val_split,
        "test_mask": test_split,
    }
    return dataset