# NOTE: the three lines below are residue from the hosting page
# ("Yaning1001's picture / Add files using upload-large-folder tool /
# c91d7b1 verified") and are not part of the module; kept as comments
# so the file remains valid Python.
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from deeprobust.graph import utils
import torch
class BaseModel(nn.Module):
    """Shared training/evaluation scaffolding for the PyG defense models.

    Subclasses are expected to define the layers and the attributes read
    here (``self.lr``, ``self.weight_decay``, ``self.device``, ``self.name``,
    ``initialize()`` and ``forward()``).
    """

    def __init__(self):
        # Nothing to set up at this level; subclasses build their layers.
        super().__init__()
def fit(self, pyg_data, train_iters=1000, initialize=True, verbose=False, patience=100, **kwargs):
if initialize:
self.initialize()
# self.data = pyg_data[0].to(self.device)
self.data = pyg_data.to(self.device)
# By default, it is trained with early stopping on validation
self.train_with_early_stopping(train_iters, patience, verbose)
def finetune(self, edge_index, edge_weight, feat=None, train_iters=10, verbose=True):
if verbose:
print(f'=== finetuning {self.name} model ===')
optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
labels = self.data.y
if feat is None:
x = self.data.x
else:
x = feat
train_mask, val_mask = self.data.train_mask, self.data.val_mask
best_loss_val = 100
best_acc_val = 0
for i in range(train_iters):
self.train()
optimizer.zero_grad()
output = self.forward(x, edge_index, edge_weight)
loss_train = F.nll_loss(output[train_mask], labels[train_mask])
loss_train.backward()
optimizer.step()
if verbose and i % 50 == 0:
print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
self.eval()
with torch.no_grad():
output = self.forward(x, edge_index)
loss_val = F.nll_loss(output[val_mask], labels[val_mask])
acc_val = utils.accuracy(output[val_mask], labels[val_mask])
# if best_loss_val > loss_val:
# best_loss_val = loss_val
# best_output = output
# weights = deepcopy(self.state_dict())
if best_acc_val < acc_val:
best_acc_val = acc_val
best_output = output
weights = deepcopy(self.state_dict())
print('best_acc_val:', best_acc_val.item())
self.load_state_dict(weights)
return best_output
def _fit_with_val(self, pyg_data, train_iters=1000, initialize=True, verbose=False, **kwargs):
if initialize:
self.initialize()
# self.data = pyg_data[0].to(self.device)
self.data = pyg_data.to(self.device)
if verbose:
print(f'=== training {self.name} model ===')
optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
labels = self.data.y
train_mask, val_mask = self.data.train_mask, self.data.val_mask
x, edge_index = self.data.x, self.data.edge_index
for i in range(train_iters):
self.train()
optimizer.zero_grad()
output = self.forward(x, edge_index)
loss_train = F.nll_loss(output[train_mask+val_mask], labels[train_mask+val_mask])
loss_train.backward()
optimizer.step()
if verbose and i % 50 == 0:
print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
def fit_with_val(self, pyg_data, train_iters=1000, initialize=True, patience=100, verbose=False, **kwargs):
if initialize:
self.initialize()
self.data = pyg_data.to(self.device)
self.data.train_mask = self.data.train_mask + self.data.val1_mask
self.data.val_mask = self.data.val2_mask
self.train_with_early_stopping(train_iters, patience, verbose)
def train_with_early_stopping(self, train_iters, patience, verbose):
"""early stopping based on the validation loss
"""
if verbose:
print(f'=== training {self.name} model ===')
optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
labels = self.data.y
train_mask, val_mask = self.data.train_mask, self.data.val_mask
early_stopping = patience
best_loss_val = 100
best_acc_val = 0
best_epoch = 0
x, edge_index = self.data.x, self.data.edge_index
for i in range(train_iters):
self.train()
optimizer.zero_grad()
output = self.forward(x, edge_index)
loss_train = F.nll_loss(output[train_mask], labels[train_mask])
loss_train.backward()
optimizer.step()
if verbose and i % 50 == 0:
print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
self.eval()
output = self.forward(x, edge_index)
loss_val = F.nll_loss(output[val_mask], labels[val_mask])
acc_val = utils.accuracy(output[val_mask], labels[val_mask])
# print(acc)
# if best_loss_val > loss_val:
# best_loss_val = loss_val
# self.output = output
# weights = deepcopy(self.state_dict())
# patience = early_stopping
# best_epoch = i
# else:
# patience -= 1
if best_acc_val < acc_val:
best_acc_val = acc_val
self.output = output
weights = deepcopy(self.state_dict())
patience = early_stopping
best_epoch = i
else:
patience -= 1
if i > early_stopping and patience <= 0:
break
if verbose:
# print('=== early stopping at {0}, loss_val = {1} ==='.format(best_epoch, best_loss_val) )
print('=== early stopping at {0}, acc_val = {1} ==='.format(best_epoch, best_acc_val) )
self.load_state_dict(weights)
def test(self):
"""Evaluate model performance on test set.
Parameters
----------
idx_test :
node testing indices
"""
self.eval()
test_mask = self.data.test_mask
labels = self.data.y
output = self.forward(self.data.x, self.data.edge_index)
# output = self.output
loss_test = F.nll_loss(output[test_mask], labels[test_mask])
acc_test = utils.accuracy(output[test_mask], labels[test_mask])
print("Test set results:",
"loss= {:.4f}".format(loss_test.item()),
"accuracy= {:.4f}".format(acc_test.item()))
return acc_test.item()
def predict(self, x=None, edge_index=None, edge_weight=None):
"""
Returns
-------
torch.FloatTensor
output (log probabilities)
"""
self.eval()
if x is None or edge_index is None:
x, edge_index = self.data.x, self.data.edge_index
return self.forward(x, edge_index, edge_weight)
def _ensure_contiguousness(self,
x,
edge_idx,
edge_weight):
if not x.is_sparse:
x = x.contiguous()
if hasattr(edge_idx, 'contiguous'):
edge_idx = edge_idx.contiguous()
if edge_weight is not None:
edge_weight = edge_weight.contiguous()
return x, edge_idx, edge_weight