| """ | |
| Title: Graph attention network (GAT) for node classification | |
| Author: [akensert](https://github.com/akensert) | |
| Date created: 2021/09/13 | |
| Last modified: 2021/12/26 | |
| Description: An implementation of a Graph Attention Network (GAT) for node classification. | |
| Accelerator: GPU | |
| """ | |
| """ | |
| ## Introduction | |
| [Graph neural networks](https://en.wikipedia.org/wiki/Graph_neural_network) | |
| is the preferred neural network architecture for processing data structured as | |
| graphs (for example, social networks or molecule structures), yielding | |
| better results than fully-connected networks or convolutional networks. | |
| In this tutorial, we will implement a specific graph neural network known as a | |
| [Graph Attention Network](https://arxiv.org/abs/1710.10903) (GAT) to predict labels of | |
| scientific papers based on what type of papers cite them (using the | |
| [Cora](https://linqs.soe.ucsc.edu/data) dataset). | |
| ### References | |
| For more information on GAT, see the original paper | |
| [Graph Attention Networks](https://arxiv.org/abs/1710.10903) as well as | |
| [DGL's Graph Attention Networks](https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html) | |
| documentation. | |
| """ | |
| """ | |
| ### Import packages | |
| """ | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras import layers | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| pd.set_option("display.max_columns", 6) | |
| pd.set_option("display.max_rows", 6) | |
| np.random.seed(2) | |
| """ | |
| ## Obtain the dataset | |
| The preparation of the [Cora dataset](https://linqs.soe.ucsc.edu/data) follows that of the | |
| [Node classification with Graph Neural Networks](https://keras.io/examples/graph/gnn_citations/) | |
| tutorial. Refer to this tutorial for more details on the dataset and exploratory data analysis. | |
| In brief, the Cora dataset consists of two files: `cora.cites` which contains *directed links* (citations) between | |
| papers; and `cora.content` which contains *features* of the corresponding papers and one | |
| of seven labels (the *subject* of the paper). | |
| """ | |
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

print(citations)

print(papers)
| """ | |
| ### Split the dataset | |
| """ | |
| # Obtain random indices | |
| random_indices = np.random.permutation(range(papers.shape[0])) | |
| # 50/50 split | |
| train_data = papers.iloc[random_indices[: len(random_indices) // 2]] | |
| test_data = papers.iloc[random_indices[len(random_indices) // 2 :]] | |
| """ | |
| ### Prepare the graph data | |
| """ | |
| # Obtain paper indices which will be used to gather node states | |
| # from the graph later on when training the model | |
| train_indices = train_data["paper_id"].to_numpy() | |
| test_indices = test_data["paper_id"].to_numpy() | |
| # Obtain ground truth labels corresponding to each paper_id | |
| train_labels = train_data["subject"].to_numpy() | |
| test_labels = test_data["subject"].to_numpy() | |
| # Define graph, namely an edge tensor and a node feature tensor | |
| edges = tf.convert_to_tensor(citations[["target", "source"]]) | |
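# Each row of `edges` is (target, source): the target paper aggregates
# information from the source papers that cite it.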
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])

# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)
| """ | |
| ## Build the model | |
| GAT takes as input a graph (namely an edge tensor and a node feature tensor) and | |
| outputs \[updated\] node states. The node states are, for each target node, neighborhood | |
| aggregated information of *N*-hops (where *N* is decided by the number of layers of the | |
| GAT). Importantly, in contrast to the | |
| [graph convolutional network](https://arxiv.org/abs/1609.02907) (GCN) | |
| the GAT makes use of attention mechanisms | |
| to aggregate information from neighboring nodes (or *source nodes*). In other words, instead of simply | |
| averaging/summing node states from source nodes (*source papers*) to the target node (*target papers*), | |
| GAT first applies normalized attention scores to each source node state and then sums. | |
| """ | |
| """ | |
| ### (Multi-head) graph attention layer | |
| The GAT model implements multi-head graph attention layers. The `MultiHeadGraphAttention` | |
| layer is simply a concatenation (or averaging) of multiple graph attention layers | |
| (`GraphAttention`), each with separate learnable weights `W`. The `GraphAttention` layer | |
| does the following: | |
| Consider inputs node states `h^{l}` which are linearly transformed by `W^{l}`, resulting in `z^{l}`. | |
| For each target node: | |
| 1. Computes pair-wise attention scores `a^{l}^{T}(z^{l}_{i}||z^{l}_{j})` for all `j`, | |
| resulting in `e_{ij}` (for all `j`). | |
| `||` denotes a concatenation, `_{i}` corresponds to the target node, and `_{j}` | |
| corresponds to a given 1-hop neighbor/source node. | |
| 2. Normalizes `e_{ij}` via softmax, so as the sum of incoming edges' attention scores | |
| to the target node (`sum_{k}{e_{norm}_{ik}}`) will add up to 1. | |
| 3. Applies attention scores `e_{norm}_{ij}` to `z_{j}` | |
| and adds it to the new target node state `h^{l+1}_{i}`, for all `j`. | |
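
In equation form (a compact restatement of steps 1-3; as in the code below, the
nonlinearity is applied only after the attention heads are merged):

$$
z^{l}_{i} = W^{l} h^{l}_{i}, \qquad
e_{ij} = \text{LeakyReLU}\left({a^{l}}^{T}\left[z^{l}_{i} \,\|\, z^{l}_{j}\right]\right), \qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}, \qquad
h^{l+1}_{i} = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} z^{l}_{j}
$$

where $\mathcal{N}(i)$ denotes the set of source nodes with an edge into target node $i$.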
| """ | |
| class GraphAttention(layers.Layer): | |
| def __init__( | |
| self, | |
| units, | |
| kernel_initializer="glorot_uniform", | |
| kernel_regularizer=None, | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| self.units = units | |
| self.kernel_initializer = keras.initializers.get(kernel_initializer) | |
| self.kernel_regularizer = keras.regularizers.get(kernel_regularizer) | |
| def build(self, input_shape): | |
| self.kernel = self.add_weight( | |
| shape=(input_shape[0][-1], self.units), | |
| trainable=True, | |
| initializer=self.kernel_initializer, | |
| regularizer=self.kernel_regularizer, | |
| name="kernel", | |
| ) | |
| self.kernel_attention = self.add_weight( | |
| shape=(self.units * 2, 1), | |
| trainable=True, | |
| initializer=self.kernel_initializer, | |
| regularizer=self.kernel_regularizer, | |
| name="kernel_attention", | |
| ) | |
| self.built = True | |
    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)
        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
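        # Clipping the logits before exponentiation bounds the exponent for
        # numerical stability; together with the segment-wise sum below, this
        # implements a softmax over each target node's incoming edges.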
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
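        # Note: broadcasting the per-target sums back to the edges via
        # `tf.repeat` + `tf.math.bincount` assumes `edges[:, 0]` is sorted in
        # ascending order (which holds for the Cora citations as loaded above).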
        attention_scores_sum = tf.repeat(
            attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
        )
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        node_states, edges = inputs

        # Obtain outputs from each attention head
        outputs = [
            attention_layer([node_states, edges])
            for attention_layer in self.attention_layers
        ]
        # Concatenate or average the node states from each head
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)
| """ | |
| ### Implement training logic with custom `train_step`, `test_step`, and `predict_step` methods | |
| Notice, the GAT model operates on the entire graph (namely, `node_states` and | |
| `edges`) in all phases (training, validation and testing). Hence, `node_states` and | |
| `edges` are passed to the constructor of the `keras.Model` and used as attributes. | |
| The difference between the phases are the indices (and labels), which gathers | |
| certain outputs (`tf.gather(outputs, indices)`). | |
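
For example, `tf.gather` simply selects the rows of the full-graph output that
correspond to the current batch of papers (a minimal illustration):

```python
outputs = tf.constant([[0.1], [0.2], [0.3], [0.4]])  # one output row per node
tf.gather(outputs, [0, 2])  # -> [[0.1], [0.3]]
```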
| """ | |
| class GraphAttentionNetwork(keras.Model): | |
| def __init__( | |
| self, | |
| node_states, | |
| edges, | |
| hidden_units, | |
| num_heads, | |
| num_layers, | |
| output_dim, | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| self.node_states = node_states | |
| self.edges = edges | |
| self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu") | |
| self.attention_layers = [ | |
| MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers) | |
| ] | |
| self.output_layer = layers.Dense(output_dim) | |
| def call(self, inputs): | |
| node_states, edges = inputs | |
| x = self.preprocess(node_states) | |
| for attention_layer in self.attention_layers: | |
| x = attention_layer([x, edges]) + x | |
| outputs = self.output_layer(x) | |
| return outputs | |
    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute probabilities
        return tf.nn.softmax(tf.gather(outputs, indices))

    def test_step(self, data):
        indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}
| """ | |
| ### Train and evaluate | |
| """ | |
| # Define hyper-parameters | |
| HIDDEN_UNITS = 100 | |
| NUM_HEADS = 8 | |
| NUM_LAYERS = 3 | |
| OUTPUT_DIM = len(class_values) | |
| NUM_EPOCHS = 100 | |
| BATCH_SIZE = 256 | |
| VALIDATION_SPLIT = 0.1 | |
| LEARNING_RATE = 3e-1 | |
| MOMENTUM = 0.9 | |
| loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) | |
| optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM) | |
| accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc") | |
| early_stopping = keras.callbacks.EarlyStopping( | |
| monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True | |
| ) | |
| # Build model | |
| gat_model = GraphAttentionNetwork( | |
| node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM | |
| ) | |
| # Compile model | |
| gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn]) | |
| gat_model.fit( | |
| x=train_indices, | |
| y=train_labels, | |
| validation_split=VALIDATION_SPLIT, | |
| batch_size=BATCH_SIZE, | |
| epochs=NUM_EPOCHS, | |
| callbacks=[early_stopping], | |
| verbose=2, | |
| ) | |
| _, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0) | |
| print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%") | |
| """ | |
| ### Predict (probabilities) | |
| """ | |
| test_probs = gat_model.predict(x=test_indices) | |
| mapping = {v: k for (k, v) in class_idx.items()} | |
| for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])): | |
| print(f"Example {i+1}: {mapping[label]}") | |
| for j, c in zip(probs, class_idx.keys()): | |
| print(f"\tProbability of {c: <24} = {j*100:7.3f}%") | |
| print("---" * 20) | |
| """ | |
| ## Conclusions | |
| The results look OK! The GAT model seems to correctly predict the subjects of the papers, | |
| based on what they cite, about 80% of the time. Further improvements could be | |
| made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers, | |
| the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout); | |
| or modify the preprocessing step. We could also try to implement *self-loops* | |
| (i.e., paper X cites paper X) and/or make the graph *undirected*. | |
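
Both ideas amount to augmenting the edge tensor before building the model. A minimal
sketch (not executed here; `self_loop_index`, `reversed_edges`, and `augmented_edges`
are illustrative names):

```python
num_nodes = node_states.shape[0]
# Self-loops: an edge from every node to itself
self_loop_index = tf.range(num_nodes, dtype=edges.dtype)
self_loops = tf.stack([self_loop_index, self_loop_index], axis=1)
# Undirected: also include the reverse of every (target, source) edge
reversed_edges = tf.reverse(edges, axis=[1])
augmented_edges = tf.concat([edges, reversed_edges, self_loops], axis=0)
# Re-sort by target, since the softmax normalization in `GraphAttention`
# assumes `edges[:, 0]` is sorted in ascending order
augmented_edges = tf.gather(augmented_edges, tf.argsort(augmented_edges[:, 0]))
```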
| """ | |