from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from deeprobust.graph import utils


class BaseModel(nn.Module):
    """Shared training/evaluation logic for the PyG-based defense models.

    Subclasses are expected to implement ``forward`` and ``initialize`` and to
    set the attributes ``name``, ``lr``, ``weight_decay`` and ``device`` used
    by the training loops below.
    """

    def __init__(self):
        super(BaseModel, self).__init__()
    def fit(self, pyg_data, train_iters=1000, initialize=True, verbose=False, patience=100, **kwargs):
        """Train on ``pyg_data`` with early stopping on validation accuracy
        (see ``train_with_early_stopping``)."""
        if initialize:
            self.initialize()

        self.data = pyg_data.to(self.device)
        self.train_with_early_stopping(train_iters, patience, verbose)
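
    # A hedged sketch of the input ``fit`` expects: a standard PyG ``Data``
    # object carrying ``x``, ``edge_index``, ``y`` and boolean ``train_mask``/
    # ``val_mask``/``test_mask`` attributes. The Planetoid loader is one way
    # to obtain such an object:
    #
    #     from torch_geometric.datasets import Planetoid
    #     pyg_data = Planetoid(root='/tmp/Cora', name='Cora')[0]
    #     model.fit(pyg_data, train_iters=1000, patience=100)
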
    def finetune(self, edge_index, edge_weight, feat=None, train_iters=10, verbose=True):
        """Fine-tune on the given graph; the checkpoint with the best
        validation accuracy is restored before returning its output."""
        if verbose:
            print(f'=== finetuning {self.name} model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        labels = self.data.y
        x = self.data.x if feat is None else feat
        train_mask, val_mask = self.data.train_mask, self.data.val_mask

        best_acc_val = 0
        best_output = None
        # Fall back to the current weights if validation accuracy never improves.
        weights = deepcopy(self.state_dict())
        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(x, edge_index, edge_weight)
            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 50 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            with torch.no_grad():
                # Evaluate on the same weighted graph used for training.
                output = self.forward(x, edge_index, edge_weight)
            acc_val = utils.accuracy(output[val_mask], labels[val_mask])

            if best_acc_val < acc_val:
                best_acc_val = acc_val
                best_output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('best_acc_val:', float(best_acc_val))
        self.load_state_dict(weights)
        return best_output
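    # ``finetune`` is typically called with an ``edge_index``/``edge_weight``
    # pair that differs from the stored graph, e.g. an adjacency repaired by a
    # defense. A hedged sketch (``clean_edge_index`` and ``clean_edge_weight``
    # are hypothetical outputs of such a defense):
    #
    #     output = model.finetune(clean_edge_index, clean_edge_weight,
    #                             train_iters=50, verbose=False)
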
    def _fit_with_val(self, pyg_data, train_iters=1000, initialize=True, verbose=False, **kwargs):
        """Train on the union of the training and validation nodes, without
        early stopping."""
        if initialize:
            self.initialize()

        self.data = pyg_data.to(self.device)
        if verbose:
            print(f'=== training {self.name} model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        labels = self.data.y
        # Union of the boolean masks: train on both splits.
        train_val_mask = self.data.train_mask | self.data.val_mask

        x, edge_index = self.data.x, self.data.edge_index
        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(x, edge_index)
            loss_train = F.nll_loss(output[train_val_mask], labels[train_val_mask])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 50 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

    def fit_with_val(self, pyg_data, train_iters=1000, initialize=True, patience=100, verbose=False, **kwargs):
        """Merge the first validation split into the training set and early
        stop on the second; expects ``val1_mask``/``val2_mask`` on the data."""
        if initialize:
            self.initialize()

        self.data = pyg_data.to(self.device)
        self.data.train_mask = self.data.train_mask | self.data.val1_mask
        self.data.val_mask = self.data.val2_mask
        self.train_with_early_stopping(train_iters, patience, verbose)
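
    # ``val1_mask``/``val2_mask`` are not produced by PyG itself; the caller
    # has to attach them. A hedged sketch that halves a standard validation
    # split (all names besides ``data``'s built-in masks are assumptions):
    #
    #     val_nodes = data.val_mask.nonzero(as_tuple=True)[0]
    #     first_half = val_nodes[: val_nodes.size(0) // 2]
    #     data.val1_mask = torch.zeros_like(data.val_mask)
    #     data.val1_mask[first_half] = True
    #     data.val2_mask = data.val_mask & ~data.val1_mask
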
    def train_with_early_stopping(self, train_iters, patience, verbose):
        """Train with early stopping based on validation accuracy."""
        if verbose:
            print(f'=== training {self.name} model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        labels = self.data.y
        train_mask, val_mask = self.data.train_mask, self.data.val_mask

        early_stopping = patience
        best_acc_val = 0
        best_epoch = 0
        # Fall back to the current weights if validation accuracy never improves.
        weights = deepcopy(self.state_dict())

        x, edge_index = self.data.x, self.data.edge_index
        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(x, edge_index)
            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 50 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            with torch.no_grad():
                output = self.forward(x, edge_index)
            acc_val = utils.accuracy(output[val_mask], labels[val_mask])

            if best_acc_val < acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping  # reset the patience counter
                best_epoch = i
            else:
                patience -= 1

            # Stop once patience runs out (but train at least `early_stopping` iterations).
            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, acc_val = {1} ==='.format(best_epoch, best_acc_val))
        self.load_state_dict(weights)

    def test(self):
        """Evaluate model performance on the test set.

        Returns
        -------
        float
            test accuracy
        """
        self.eval()
        test_mask = self.data.test_mask
        labels = self.data.y
        with torch.no_grad():
            output = self.forward(self.data.x, self.data.edge_index)
        loss_test = F.nll_loss(output[test_mask], labels[test_mask])
        acc_test = utils.accuracy(output[test_mask], labels[test_mask])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self, x=None, edge_index=None, edge_weight=None):
        """Forward pass in eval mode; falls back to the stored data when
        ``x`` or ``edge_index`` is not given. Gradients are not disabled here,
        so callers can still backpropagate through the output.

        Returns
        -------
        torch.FloatTensor
            output (log probabilities)
        """
        self.eval()
        if x is None or edge_index is None:
            x, edge_index = self.data.x, self.data.edge_index
        return self.forward(x, edge_index, edge_weight)

    def _ensure_contiguousness(self, x, edge_idx, edge_weight):
        """Make the inputs contiguous in memory; sparse features are left as-is."""
        if not x.is_sparse:
            x = x.contiguous()
        if hasattr(edge_idx, 'contiguous'):
            edge_idx = edge_idx.contiguous()
        if edge_weight is not None:
            edge_weight = edge_weight.contiguous()
        return x, edge_idx, edge_weight
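

# The sketch below is illustrative, not part of the library: it shows one way
# to subclass BaseModel. The two-layer GCN, the hyperparameters, and the Cora
# loading are assumptions; BaseModel itself only requires ``forward``,
# ``initialize`` and the ``name``/``lr``/``weight_decay``/``device`` attributes.
if __name__ == '__main__':
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import GCNConv

    class ExampleGCN(BaseModel):
        def __init__(self, nfeat, nhid, nclass, lr=0.01, weight_decay=5e-4, device='cpu'):
            super(ExampleGCN, self).__init__()
            self.name = 'ExampleGCN'
            self.lr = lr
            self.weight_decay = weight_decay
            self.device = device
            self.conv1 = GCNConv(nfeat, nhid)
            self.conv2 = GCNConv(nhid, nclass)

        def initialize(self):
            self.conv1.reset_parameters()
            self.conv2.reset_parameters()

        def forward(self, x, edge_index, edge_weight=None):
            x = F.relu(self.conv1(x, edge_index, edge_weight))
            x = self.conv2(x, edge_index, edge_weight)
            return F.log_softmax(x, dim=1)

    dataset = Planetoid(root='/tmp/Cora', name='Cora')
    model = ExampleGCN(dataset.num_features, 16, dataset.num_classes)
    model.fit(dataset[0], train_iters=200, patience=50, verbose=True)
    model.test()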