Upload 19 files
Browse files- data/baby/test.npy +3 -0
- data/baby/train.npy +3 -0
- data/baby/valid.npy +3 -0
- data/clothing/test.npy +3 -0
- data/clothing/train.npy +3 -0
- data/clothing/valid.npy +3 -0
- data/sports/test.npy +3 -0
- data/sports/train.npy +3 -0
- data/sports/valid.npy +3 -0
- main.py +114 -0
- model.py +279 -0
- trainer.py +48 -0
- utils/configurator.py +63 -0
- utils/data_loader.py +209 -0
- utils/evaluator.py +56 -0
- utils/helper.py +381 -0
- utils/logger.py +46 -0
- utils/metrics.py +55 -0
- utils/parser.py +54 -0
data/baby/test.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b890fa339aca21fac9c17c27a9f1ea163ff8df9a6e4caf353c7f7cf7d689c745
|
| 3 |
+
size 347040
|
data/baby/train.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d4c495814d61a1ea84b756f11096d468cc07113d05aecbabcef8de9cf7a1387
|
| 3 |
+
size 1896944
|
data/baby/valid.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:388bd7922d74311e6afc6fbc9f6299cc85b32310109f7d67939ef90e097babb2
|
| 3 |
+
size 329072
|
data/clothing/test.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:272ba8177a70a973842bc30eafe0d7ae7d641c5117600a0f23b24319c7ad7a61
|
| 3 |
+
size 659152
|
data/clothing/train.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:685fe396e4a8d3c6cf465a9e8421e1fa07b7d187e7fa1dc996578f11a7de896a
|
| 3 |
+
size 3157536
|
data/clothing/valid.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f8c581fa0680e5886d4781a4a551868d47138c9b2d4d1cbe886d17cef3fedcd
|
| 3 |
+
size 642528
|
data/sports/test.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9396ff247c58b918a915bc95f19574492c0cf3f8ae104b77ac2447b15f77970
|
| 3 |
+
size 640592
|
data/sports/train.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1f8c7a04efd251941083dd25bff7c154e2e2ff7bc3ccc240d45bd11f906fbf2
|
| 3 |
+
size 3494672
|
data/sports/valid.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a27b2d104aef08582fbf241dc3c77b9ce7fc0788962ec6f36eb9b8c0dafba649
|
| 3 |
+
size 606512
|
main.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import platform
|
| 4 |
+
from time import time
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from trainer import train
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from utils.parser import parse_args
|
| 9 |
+
from utils.logger import init_logger
|
| 10 |
+
from utils.configurator import Config
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tensorboardX import SummaryWriter
|
| 13 |
+
from utils.evaluator import evaluate_model
|
| 14 |
+
from utils.data_loader import Load_dataset, Load_eval_dataset
|
| 15 |
+
from utils.helper import early_stopping, plot_curve, res_output, stop_log, update_result, sele_para
|
| 16 |
+
from model import REARM
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Net:
    """Experiment driver: wires up config, data loaders, the REARM model,
    optimizer and LR schedule, then runs the train/validate/test loop
    with early stopping."""

    def __init__(self, args):
        # Complete initialization of all parameters (including random seeds).
        self.config = Config(args)
        # Use logger
        self.logger = init_logger(self.config)
        self.logger.info(self.config)
        self.logger.info('██Server: \t' + platform.node())
        self.logger.info('██Dir: \t' + os.getcwd() + '\n')

        cfg = self.config
        self.device = cfg.device
        self.model_name = cfg.model_name
        self.dataset_name = cfg.dataset
        self.batch_size = cfg.batch_size
        self.num_workers = cfg.num_workers
        self.learning_rate = cfg.learning_rate
        self.num_epoch = cfg.num_epoch
        self.topk = cfg.topk
        self.metrics = cfg.metrics
        self.valid_metric = cfg.valid_metric
        self.stopping_step = cfg.stopping_step
        self.reg_weight = cfg.reg_weight

        # Early-stopping / best-result bookkeeping.
        self.cur_step = 0
        self.best_valid_score = -1
        self.best_valid_result = {}
        self.best_test_upon_valid = {}

        # Writer will output to ./runs/ directory by default.
        self.writer = SummaryWriter() if cfg.writer else None

        # Perform experimental configurations: datasets and loaders.
        dataset = Load_dataset(cfg)
        valid_dataset, test_dataset = dataset.load_eval_data()
        self.train_data = DataLoader(dataset, batch_size=self.batch_size,
                                     shuffle=True, num_workers=self.num_workers)
        self.valid_data = Load_eval_dataset("Validation", cfg, valid_dataset)
        self.test_data = Load_eval_dataset("Testing", cfg, test_dataset)

        # Model, optimizer (AdamW carries the L2 penalty via weight_decay)
        # and an exponentially decaying learning-rate schedule.
        self.model = REARM(cfg, dataset).to(self.device)
        self.optimizer = optim.AdamW(self.model.parameters(), self.learning_rate,
                                     weight_decay=self.reg_weight)
        decay_base, decay_steps = cfg.learning_rate_scheduler
        self.lr_scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda=lambda epoch: decay_base ** (epoch / decay_steps))
        self.logger.info(self.model)

    def plot_train_loss(self):
        """Render the recorded training-loss curve (delegates to utils.helper)."""
        plot_curve(self)

    def run(self):
        """Train for up to num_epoch epochs with early stopping.

        Returns:
            dict: test metrics recorded at the best validation epoch, or a
            sentinel ``{"Recall@20": -1}`` when the loss diverged to NaN
            before any test evaluation happened.
        """
        run_start_time = time()
        for epoch_idx in tqdm(range(self.num_epoch)):
            train_start_time = time()
            train_loss = train(self, epoch_idx)

            # Abort (but still report the best result so far) on divergence.
            if torch.isnan(train_loss[0]):
                ret_value = self.best_test_upon_valid if self.best_test_upon_valid else {"Recall@20": -1}
                stop_output = '\n ' + str(self.config.dataset) + ' key parameter: ' + sele_para(self.config)
                self.logger.info(stop_output)
                self.logger.info('Loss is nan at epoch: {}; last value is {}Exiting.'.format(epoch_idx, ret_value))
                return ret_value

            self.lr_scheduler.step()

            self.logger.info(res_output(epoch_idx, train_start_time, time(), train_loss, "train"))

            # Validation pass + early-stopping state update.
            valid_start_time = time()
            valid_score, valid_result = evaluate_model(self, epoch_idx, self.valid_data, t_or_v="valid")
            (self.best_valid_score, self.cur_step,
             stop_flag, update_flag) = early_stopping(valid_score, self.best_valid_score,
                                                      self.cur_step, self.stopping_step)
            self.best_valid_result[epoch_idx] = self.best_valid_score
            self.logger.info(res_output(epoch_idx, valid_start_time, time(), valid_result, t_or_v="valid"))

            if update_flag:
                # New best validation score: also evaluate on the test split.
                test_start_time = time()
                _, test_result = evaluate_model(self, epoch_idx, self.test_data, t_or_v="test")
                self.logger.info(res_output(epoch_idx, test_start_time, time(), test_result, t_or_v="test"))
                update_result(self, test_result)

            if stop_flag:
                stop_log(self, epoch_idx, run_start_time)
                break
            print('patience ==> %d' % (self.stopping_step - self.cur_step))
        return self.best_test_upon_valid
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == '__main__':
    # Parse CLI arguments, build the experiment, and run it to completion.
    cli_args = parse_args()
    experiment = Net(cli_args)
    best_score = experiment.run()
|
model.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from utils.helper import get_norm_adj_mat, ssl_loss, topk_sample, cal_diff_loss, propgt_info
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class REARM(nn.Module):
    """Multi-modal recommender.

    Combines ID embeddings with item visual/textual features and user
    modality preferences, propagates them over homogeneous (user-user,
    item-item) and heterogeneous (user-item) graphs, fuses the modalities
    with self-/cross-attention, and distills shared knowledge through
    low-rank meta networks.  Trained with BPR loss plus a contrastive
    (SSL) term and a modality-difference (orthogonality) term.
    """

    def __init__(self, config, dataset):
        super(REARM, self).__init__()

        # ----- sizes and hyper-parameters --------------------------------
        self.n_users = dataset.n_users
        self.n_items = dataset.n_items
        self.n_nodes = self.n_users + self.n_items
        self.i_v_feat = dataset.i_v_feat          # raw item visual features (may be None)
        self.i_t_feat = dataset.i_t_feat          # raw item textual features (may be None)
        self.embedding_dim = config.embedding_dim
        self.feat_embed_dim = config.embedding_dim  # feature dim is tied to the id-embedding dim
        self.dim_feat = self.feat_embed_dim
        self.reg_weight = config.reg_weight
        self.device = config.device
        self.cl_tmp = config.cl_tmp               # contrastive-loss temperature
        self.cl_loss_weight = config.cl_loss_weight
        self.diff_loss_weight = config.diff_loss_weight
        self.n_layers = config.n_layers
        self.num_user_co = config.num_user_co
        self.num_item_co = config.num_item_co
        self.user_aggr_mode = config.user_aggr_mode
        self.n_ii_layers = config.n_ii_layers
        self.n_uu_layers = config.n_uu_layers
        self.k = config.rank                      # rank of the low-rank meta decomposition
        self.uu_co_weight = config.uu_co_weight
        self.ii_co_weight = config.ii_co_weight

        # ----- user and item co-occurrence graph inputs ------------------
        self.topK_users = dataset.topK_users
        self.topK_items = dataset.topK_items
        self.dict_user_co_occ_graph = dataset.dict_user_co_occ_graph
        self.dict_item_co_occ_graph = dataset.dict_item_co_occ_graph
        self.topK_users_counts = dataset.topK_users_counts
        self.topK_items_counts = dataset.topK_items_counts

        self.s_drop = config.s_drop               # self-attention dropout
        self.m_drop = config.m_drop               # cross-attention dropout
        self.ly_norm = nn.LayerNorm(self.feat_embed_dim)

        # Attention with embed_dim=1: attends across feature channels.
        self.self_i_attn1 = nn.MultiheadAttention(1, 1, dropout=self.s_drop, batch_first=True)
        self.self_i_attn2 = nn.MultiheadAttention(1, 1, dropout=self.s_drop, batch_first=True)

        self.mutual_i_attn1 = nn.MultiheadAttention(1, 1, dropout=self.m_drop, batch_first=True)
        self.mutual_i_attn2 = nn.MultiheadAttention(1, 1, dropout=self.m_drop, batch_first=True)

        self.user_id_embedding = nn.Embedding(self.n_users, self.embedding_dim).to(self.device)
        self.item_id_embedding = nn.Embedding(self.n_items, self.embedding_dim).to(self.device)

        self.prl = nn.PReLU().to(self.device)

        # Fixed (+1, -1) column vector: turns paired (pos, neg) scores into
        # their difference for the BPR loss via one matmul.
        self.cal_bpr = torch.tensor([[1.0], [-1.0]]).to(self.device)

        # load dataset info
        self.norm_adj = get_norm_adj_mat(self, dataset.sparse_inter_matrix(form='coo')).to(self.device)
        # Process to obtain user co-occurrence matrix (n_users * num_user_co)
        self.user_co_graph = topk_sample(self.n_users, self.dict_user_co_occ_graph, self.num_user_co,
                                         self.topK_users, self.topK_users_counts, 'softmax',
                                         self.device)
        # Process to obtain item co-occurrence matrix (n_items * num_item_co)
        self.item_co_graph = topk_sample(self.n_items, self.dict_item_co_occ_graph, self.num_item_co,
                                         self.topK_items, self.topK_items_counts, 'softmax',
                                         self.device)

        # Item similarity matrix (n_items * n_items)
        self.i_mm_adj = dataset.i_mm_adj
        # User similarity matrix (n_users * n_users)
        self.u_mm_adj = dataset.u_mm_adj

        # Strengthened ii / uu graphs: convex blend of co-occurrence and
        # multi-modal similarity graphs.
        self.stre_ii_graph = self.ii_co_weight * self.item_co_graph + (1.0 - self.ii_co_weight) * self.i_mm_adj
        self.stre_uu_graph = self.uu_co_weight * self.user_co_graph + (1.0 - self.uu_co_weight) * self.u_mm_adj

        if self.i_v_feat is not None:
            self.image_embedding = nn.Embedding.from_pretrained(self.i_v_feat, freeze=False).to(self.device)
            self.image_i_trs = nn.Linear(self.i_v_feat.shape[1], self.feat_embed_dim)
            # BUGFIX: wrap the tensor already moved to the device.  The old
            # `nn.Parameter(x, requires_grad=True).to(device)` returns a plain
            # non-leaf tensor, so the preference matrix was never registered
            # as a module parameter and never updated by the optimizer.
            self.user_v_prefer = nn.Parameter(dataset.u_v_interest.to(self.device))
            self.image_u_trs = nn.Linear(self.i_v_feat.shape[1], self.feat_embed_dim)

        if self.i_t_feat is not None:
            self.text_embedding = nn.Embedding.from_pretrained(self.i_t_feat, freeze=False).to(self.device)
            self.text_i_trs = nn.Linear(self.i_t_feat.shape[1], self.feat_embed_dim)
            self.user_t_prefer = nn.Parameter(dataset.u_t_interest.to(self.device))  # see BUGFIX above
            self.text_u_trs = nn.Linear(self.i_t_feat.shape[1], self.feat_embed_dim)

        # MLP(input_dim, feature_dim, output_dim, device): meta networks that
        # emit the low-rank personalized transformation factors.
        self.mlp_u1 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.mlp_u2 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.mlp_i1 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.mlp_i2 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.meta_netu = nn.Linear(self.feat_embed_dim * 2, self.feat_embed_dim, bias=True)  # knowledge compression
        self.meta_neti = nn.Linear(self.feat_embed_dim * 2, self.feat_embed_dim, bias=True)  # knowledge compression

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize id embeddings (normal) and modality transforms (Xavier)."""
        nn.init.normal_(self.user_id_embedding.weight, std=0.1)
        nn.init.normal_(self.item_id_embedding.weight, std=0.1)

        # BUGFIX: only initialize the transforms that were actually created;
        # __init__ skips them when the corresponding modality is absent.
        if self.i_v_feat is not None:
            nn.init.xavier_normal_(self.image_i_trs.weight)
            nn.init.xavier_normal_(self.image_u_trs.weight)
        if self.i_t_feat is not None:
            nn.init.xavier_normal_(self.text_i_trs.weight)
            nn.init.xavier_normal_(self.text_u_trs.weight)

    def forward(self):
        """Compute the fused representation for all nodes.

        Returns:
            Tensor of shape ((n_users + n_items), 3 * feat_embed_dim):
            concatenation of [visual | textual | shared] channels, each
            fused with the propagated id embedding.
        """
        # Unify feature dimensions of the item multi-modal features.
        trs_item_v_feat = self.image_i_trs(self.image_embedding.weight)
        trs_item_t_feat = self.text_i_trs(self.text_embedding.weight)  # n_items * feat_embed_dim

        trs_user_v_prefer = self.image_u_trs(self.user_v_prefer)
        trs_user_t_prefer = self.text_u_trs(self.user_t_prefer)  # n_users * feat_embed_dim

        # ====================================================================
        # Homogeneous relation learning
        # ====================================================================
        # Item side: propagate [id | visual | textual] over the strengthened
        # item-item graph, then L2-normalize.
        item_v_t = torch.cat((trs_item_v_feat, trs_item_t_feat), dim=-1)
        item_id_v_t = torch.cat((self.item_id_embedding.weight, item_v_t), dim=-1)
        item_id_v_t = propgt_info(item_id_v_t, self.n_ii_layers, self.stre_ii_graph, last_layer=True)
        item_id_v_t = F.normalize(item_id_v_t)

        # Split back into the three channels (dims are all equal).
        item_id_ii = item_id_v_t[:, :self.embedding_dim]
        gnn_i_v_feat = item_id_v_t[:, self.embedding_dim:-self.feat_embed_dim]
        gnn_i_t_feat = item_id_v_t[:, -self.feat_embed_dim:]

        # User side: same propagation over the strengthened user-user graph.
        user_v_t = torch.cat((trs_user_v_prefer, trs_user_t_prefer), dim=-1)
        user_id_v_t = torch.cat((self.user_id_embedding.weight, user_v_t), dim=-1)
        user_id_v_t = propgt_info(user_id_v_t, self.n_uu_layers, self.stre_uu_graph, last_layer=True)

        user_id_v_t = F.normalize(user_id_v_t)
        user_id_uu = user_id_v_t[:, :self.embedding_dim]
        gnn_u_v_prefer = user_id_v_t[:, self.embedding_dim:-self.feat_embed_dim]
        gnn_u_t_prefer = user_id_v_t[:, -self.feat_embed_dim:]

        # ====================================================================
        # Item feature attention integration
        # ====================================================================
        # Self-attention across channels of the item visual features,
        # with a residual + LayerNorm + PReLU.
        item_v_feat, _ = self.self_i_attn1(gnn_i_v_feat.unsqueeze(2), gnn_i_v_feat.unsqueeze(2),
                                           gnn_i_v_feat.unsqueeze(2), need_weights=False)
        item_v_feat = self.ly_norm(gnn_i_v_feat + item_v_feat.squeeze())
        item_v_feat = self.prl(item_v_feat)

        # Self-attention for the item textual features.
        item_t_feat, _ = self.self_i_attn2(gnn_i_t_feat.unsqueeze(2), gnn_i_t_feat.unsqueeze(2),
                                           gnn_i_t_feat.unsqueeze(2), need_weights=False)
        item_t_feat = self.ly_norm(gnn_i_t_feat + item_t_feat.squeeze())
        item_t_feat = self.prl(item_t_feat)

        # --------------------------------------------------------------------
        # Text -> visual cross-attention.
        i_t2v_feat, _ = self.mutual_i_attn1(item_t_feat.unsqueeze(2), item_v_feat.unsqueeze(2),
                                            item_v_feat.unsqueeze(2), need_weights=False)
        item_t2v_feat = self.ly_norm(item_v_feat + i_t2v_feat.squeeze())
        item_t2v_feat = self.prl(item_t2v_feat)

        # Visual -> text cross-attention.  (item_t_feat is already 2-D here;
        # the extra .squeeze() it used to carry was a no-op and is dropped
        # for symmetry with the branch above.)
        i_v2t_feat, _ = self.mutual_i_attn2(item_v_feat.unsqueeze(2), item_t_feat.unsqueeze(2),
                                            item_t_feat.unsqueeze(2), need_weights=False)
        item_v2t_feat = self.ly_norm(item_t_feat + i_v2t_feat.squeeze())
        item_v2t_feat = self.prl(item_v2t_feat)

        user_v_prefer = self.prl(gnn_u_v_prefer)  # (n_users, feat_embed_dim)
        user_t_prefer = self.prl(gnn_u_t_prefer)

        # ====================================================================
        # Heterogeneous relation learning
        # ====================================================================
        # Stack users on top of items and propagate over the user-item graph.
        item_v_t_feat = torch.cat((item_t2v_feat, item_v2t_feat), dim=-1)    # (n_items, 2d)
        user_v_t_prefer = torch.cat((user_v_prefer, user_t_prefer), dim=-1)  # (n_users, 2d)
        ego_feat_prefer = torch.cat((user_v_t_prefer, item_v_t_feat), dim=0)  # (n_nodes, 2d)
        self.fin_feat_prefer = propgt_info(ego_feat_prefer, self.n_layers, self.norm_adj)

        ego_id_embed = torch.cat((user_id_uu, item_id_ii), dim=0)  # (n_nodes, d)
        fin_id_embed = propgt_info(ego_id_embed, self.n_layers, self.norm_adj)

        share_knowldge = self.meta_extra_share(fin_id_embed, self.fin_feat_prefer)  # (n_nodes, d)

        # Fuse each channel with the propagated id embedding.
        fin_v = self.prl(self.fin_feat_prefer[:, :self.embedding_dim]) + fin_id_embed
        fin_t = self.prl(self.fin_feat_prefer[:, self.embedding_dim:]) + fin_id_embed
        fin_share = self.prl(share_knowldge) + fin_id_embed

        temp_full_feat_prefer = torch.cat((fin_v, fin_t), dim=-1)
        representation = torch.cat((temp_full_feat_prefer, fin_share), dim=-1)

        return representation

    def loss(self, user_tensor, item_tensor):
        """BPR loss + contrastive loss + modality-difference loss.

        Args:
            user_tensor: user indices, paired so that consecutive entries
                correspond to one (positive, negative) interaction pair.
            item_tensor: matching item indices.

        Returns:
            (total_loss, bpr_loss, reg_loss, mul_vt_cl_loss, mul_diff_loss)
            -- reg_loss is always 0 because L2 regularization is delegated
            to AdamW's weight_decay.
        """
        user_tensor_flatten = user_tensor.view(-1)
        item_tensor_flatten = item_tensor.view(-1)
        out = self.forward()
        user_rep = out[user_tensor_flatten]
        item_rep = out[item_tensor_flatten]

        # Pairwise scores reshaped to (batch, 2) = (pos, neg); the fixed
        # (+1, -1) vector turns each row into pos - neg.
        score = torch.sum(user_rep * item_rep, dim=1).view(-1, 2)
        bpr_score = torch.matmul(score, self.cal_bpr)
        bpr_loss = -torch.mean(nn.LogSigmoid()(bpr_score))

        # Cross-modal contrastive losses (visual channel vs textual channel).
        i_mul_vt_cl_loss = ssl_loss(self.fin_feat_prefer[:, :self.feat_embed_dim],
                                    self.fin_feat_prefer[:, -self.feat_embed_dim:], item_tensor_flatten, self.cl_tmp)
        u_mul_vt_cl_loss = ssl_loss(self.fin_feat_prefer[:, :self.feat_embed_dim],
                                    self.fin_feat_prefer[:, -self.feat_embed_dim:], user_tensor_flatten, self.cl_tmp)
        mul_vt_cl_loss = self.cl_loss_weight * (i_mul_vt_cl_loss + u_mul_vt_cl_loss)

        # Modal-unique orthogonality constraint.
        mul_i_diff_loss = cal_diff_loss(self.fin_feat_prefer, user_tensor, self.feat_embed_dim)
        mul_u_diff_loss = cal_diff_loss(self.fin_feat_prefer, item_tensor, self.feat_embed_dim)
        mul_diff_loss = self.diff_loss_weight * (mul_i_diff_loss + mul_u_diff_loss)

        reg_loss = 0  # realized in AdamW (weight_decay)
        total_loss = bpr_loss + reg_loss + mul_vt_cl_loss + mul_diff_loss

        return total_loss, bpr_loss, reg_loss, mul_vt_cl_loss, mul_diff_loss

    def full_sort_predict(self, interaction):
        """Score every item for the users in interaction[0].

        Returns a (len(users), n_items) score matrix.
        """
        user = interaction[0]
        representation = self.forward()
        u_reps, i_reps = torch.split(representation, [self.n_users, self.n_items], dim=0)
        score_mat_ui = torch.matmul(u_reps[user], i_reps.t())
        return score_mat_ui

    def meta_extra_share(self, id_embed, prefer_or_feat):
        """Extract modality-shared knowledge via low-rank meta networks.

        The meta networks read the (detached) fused modality features and
        emit per-node low-rank factors; applying them to the id embeddings
        is equivalent to a two-layer personalized linear transform.
        """
        u_id_embed = id_embed[:self.n_users, :]
        i_id_embed = id_embed[self.n_users:, :]

        u_v_t = prefer_or_feat[:self.n_users, :]
        i_v_t = prefer_or_feat[self.n_users:, :]

        # Meta-knowledge extraction (detached so gradients do not flow back
        # through the feature branch here).
        u_knowldge = self.meta_netu(u_v_t).detach()
        i_knowldge = self.meta_neti(i_v_t).detach()

        """ Personalized transformation parameter matrix """
        # Low-rank matrix decomposition: factors of shape (d, k) and (k, d).
        metau1 = self.mlp_u1(u_knowldge).reshape(-1, self.feat_embed_dim, self.k)  # N_u*d*k
        metau2 = self.mlp_u2(u_knowldge).reshape(-1, self.k, self.feat_embed_dim)  # N_u*k*d
        metai1 = self.mlp_i1(i_knowldge).reshape(-1, self.feat_embed_dim, self.k)  # N_i*d*k
        metai2 = self.mlp_i2(i_knowldge).reshape(-1, self.k, self.feat_embed_dim)  # N_i*k*d
        # Population means act as shared biases for the factors.
        meta_biasu = torch.mean(metau1, dim=0)   # d*k
        meta_biasu1 = torch.mean(metau2, dim=0)  # k*d
        meta_biasi = torch.mean(metai1, dim=0)   # d*k
        meta_biasi1 = torch.mean(metai2, dim=0)  # k*d
        low_weightu1 = F.softmax(metau1 + meta_biasu, dim=1)
        low_weightu2 = F.softmax(metau2 + meta_biasu1, dim=1)
        low_weighti1 = F.softmax(metai1 + meta_biasi, dim=1)
        low_weighti2 = F.softmax(metai2 + meta_biasi1, dim=1)

        # Apply the learned factors: equivalent to a two-layer linear network
        # whose weights are personalized per node.
        u_middle_knowldge = torch.sum(torch.multiply(u_id_embed.unsqueeze(-1), low_weightu1), dim=1)
        u_share_knowldge = torch.sum(torch.multiply(u_middle_knowldge.unsqueeze(-1), low_weightu2), dim=1)
        i_middle_knowldge = torch.sum(torch.multiply(i_id_embed.unsqueeze(-1), low_weighti1), dim=1)
        i_share_knowldge = torch.sum(torch.multiply(i_middle_knowldge.unsqueeze(-1), low_weighti2), dim=1)

        share_knowldge = torch.cat((u_share_knowldge, i_share_knowldge), dim=0)
        return share_knowldge
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> PReLU -> Linear, with the output
    L2-normalized along the last dimension."""

    def __init__(self, input_dim, feature_dim, output_dim, device):
        super(MLP, self).__init__()
        self.device = device
        # Submodule names kept stable so state_dict keys do not change.
        self.linear_pre = nn.Linear(input_dim, feature_dim, bias=True)
        self.prl = nn.PReLU().to(self.device)
        self.linear_out = nn.Linear(feature_dim, output_dim, bias=True)

    def forward(self, data):
        """Project `data` to output_dim and return unit-norm embeddings."""
        hidden = self.prl(self.linear_pre(data))
        return F.normalize(self.linear_out(hidden), p=2, dim=-1)
|
trainer.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from torch.nn.utils.clip_grad import clip_grad_norm_
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# def train(length, epoch, dataloader, model, optimizer, batch_size, writer=None):
|
| 8 |
+
def train(self, epoch_idx):
    """Run one training epoch over self.train_data.

    Args:
        self: the experiment driver (provides model, optimizer, train_data,
            writer and logger).
        epoch_idx: current epoch index (used for logging / tensorboard).

    Returns:
        A 5-element list of mean losses
        [total, bpr, reg, mul_vt_cl, diff] on success, or the 2-tuple
        (nan_loss, tensor(0.0)) when the loss turned NaN mid-epoch
        (the caller checks element 0 with torch.isnan in both cases).
    """
    self.model.train()
    sum_loss = 0.0
    sum_bpr_loss, sum_reg_loss = 0.0, 0.0
    sum_diff_loss, sum_mul_vt_cl_loss = 0.0, 0.0

    step = 0.0

    def _no_grad(value):
        # reg_loss may be a plain number (0 when AdamW handles the L2 term).
        return value.detach() if torch.is_tensor(value) else value

    for batch_idx, interactions in enumerate(self.train_data):
        self.optimizer.zero_grad()
        loss, bpr_loss, reg_loss, mul_vt_cl_loss, diff_loss = self.model.loss(interactions[0],
                                                                              interactions[1])
        if torch.isnan(loss):
            self.logger.info('Loss is nan at epoch: {}, batch index: {}. Exiting.'.format(epoch_idx, batch_idx))
            return loss, torch.tensor(0.0)

        loss.backward()
        self.optimizer.step()

        step += 1.0
        # BUGFIX: detach before accumulating.  Summing the live loss tensors
        # kept every batch's autograd graph alive for the whole epoch,
        # steadily growing memory use.
        sum_loss += loss.detach()
        sum_bpr_loss += _no_grad(bpr_loss)
        sum_reg_loss += _no_grad(reg_loss)
        sum_mul_vt_cl_loss += _no_grad(mul_vt_cl_loss)
        sum_diff_loss += _no_grad(diff_loss)

    if step == 0.0:
        # Empty loader: report zeros instead of dividing by zero.
        zero = torch.tensor(0.0)
        return [zero, zero, zero, zero, zero]

    mean_loss = sum_loss / step
    mean_bpr_loss = sum_bpr_loss / step
    mean_reg_loss = sum_reg_loss / step
    mean_mul_vt_cl_loss = sum_mul_vt_cl_loss / step
    mean_diff_loss = sum_diff_loss / step

    if self.writer is not None:
        self.writer.add_scalar('loss/train', mean_loss, epoch_idx)
        self.writer.add_scalar('loss/bpr_loss', mean_bpr_loss, epoch_idx)
        self.writer.add_scalar('loss/reg_loss', mean_reg_loss, epoch_idx)
        self.writer.add_scalar('loss/mul_vt_cl_loss', mean_mul_vt_cl_loss, epoch_idx)
        self.writer.add_scalar('loss/diff_loss', mean_diff_loss, epoch_idx)

    return [mean_loss, mean_bpr_loss, mean_reg_loss, mean_mul_vt_cl_loss, mean_diff_loss]
|
utils/configurator.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import multiprocessing
|
| 4 |
+
from utils.helper import init_seed
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Config(object):
    """Flattens parsed CLI arguments into named attributes, resolves the
    compute device and worker count, and seeds all RNGs.

    NOTE: attribute assignment order is deliberate -- __str__ renders
    self.__dict__ in insertion order.
    """

    def __init__(self, args):
        # --- experiment identity and optimization settings ---------------
        self.model_name = args.model_name
        self.dataset = args.dataset
        self.learning_rate = args.l_r
        self.learning_rate_scheduler = args.learning_rate_scheduler
        self.embedding_dim = args.embedding_dim
        self.num_epoch = args.num_epoch
        self.reg_weight = args.reg_weight
        self.use_gpu = args.use_gpu
        self.gpu_id = args.gpu_id
        self.seed = args.seed

        # --- batching, evaluation and early stopping ---------------------
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        self.topk = args.topk
        self.valid_metric = args.valid_metric
        self.metrics = args.metrics
        self.stopping_step = args.stopping_step
        self.n_layers = args.num_layer

        # --- model-specific hyper-parameters -----------------------------
        self.rank = args.rank
        self.s_drop = args.s_drop
        self.m_drop = args.m_drop
        self.cl_tmp = args.cl_tmp
        self.item_knn_k = args.item_knn_k
        self.user_knn_k = args.user_knn_k
        self.num_user_co = args.user_knn_k  # same as user_knn_k to compute
        self.num_item_co = args.item_knn_k  # same as item_knn_k to compute
        self.n_ii_layers = args.n_ii_layers
        self.n_uu_layers = args.n_uu_layers
        self.writer = args.with_tensorboard
        self.uu_co_weight = args.uu_co_weight
        self.ii_co_weight = args.ii_co_weight
        self.cl_loss_weight = args.cl_loss_weight
        self.user_aggr_mode = args.user_aggr_mode
        self.i_mm_image_weight = args.i_mm_image_weight
        self.u_mm_image_weight = args.u_mm_image_weight
        self.diff_loss_weight = args.diff_loss_weight

        self._init_device(args)
        init_seed(self.seed)

    def _init_device(self, args):
        """Pick the torch device and cap the DataLoader worker count."""
        if self.use_gpu:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() and self.use_gpu else "cpu")

        # Never request more workers than half the available CPU cores.
        self.num_workers = min(multiprocessing.cpu_count() // 2, args.num_workers)

    def __str__(self):
        entries = ["{} = {}".format(arg, value) for arg, value in self.__dict__.items()]
        return '\nModel arguments: ' + ',\n'.join(entries) + '.\n'
|
utils/data_loader.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from utils.helper import *
|
| 3 |
+
from logging import getLogger
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import torch.sparse as tsp
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseDataset(object):
    """Load train/valid/test interaction splits plus multimodal item features
    and derive the basic dataset statistics.

    After ``__init__``, item ids in train/valid/test data are offset by
    ``n_users`` so that user and item ids occupy disjoint ranges of one id space.
    """

    def __init__(self, config):
        self.config = config
        self.logger = getLogger("normal")
        self.device = config.device
        self.dataset_name = config.dataset
        self.load_all_data()
        self.processed_eval_data()
        # Count distinct users / items across all three splits.
        self.n_users = len(set(self.train_data[:, 0]) | set(self.valid_data[:, 0]) | set(self.test_data[:, 0]))
        self.n_items = len(set(self.train_data[:, 1]) | set(self.valid_data[:, 1]) | set(self.test_data[:, 1]))
        self.train_data[:, 1] += self.n_users  # Ensure that the ids are different
        self.valid_data[:, 1] += self.n_users  # Ensure that the ids are different
        self.test_data[:, 1] += self.n_users  # Ensure that the ids are different
        self.dict_user_items()

    def load_all_data(self):
        """Load the interaction splits and the visual / textual item features
        from ``./data/<dataset>/``."""
        dataset_path = str('./data/' + self.dataset_name)
        self.train_dataset = np.load(dataset_path + '/train.npy', allow_pickle=True)  # [[1,2,3],[2,3,0]]
        v_feat = np.load(dataset_path + '/image_feat.npy', allow_pickle=True)
        self.i_v_feat = torch.from_numpy(v_feat).type(torch.FloatTensor).to(self.device)  # 4096
        t_feat = np.load(dataset_path + '/text_feat.npy', allow_pickle=True)
        self.i_t_feat = torch.from_numpy(t_feat).type(torch.FloatTensor).to(self.device)  # 384
        self.valid_dataset = np.load(dataset_path + '/valid.npy', allow_pickle=True)  # [[1,2,3],[2,3,0]]
        self.test_dataset = np.load(dataset_path + '/test.npy', allow_pickle=True)  # [[1,2,3],[2,3,0]]

    def processed_eval_data(self):
        """Transpose each raw split so each row is one [user, item] pair.

        The raw ``*.npy`` arrays store users in row 0 and items in row 1
        (see ``cal_sparse_inter_matrix``, which reads ``train_dataset[0, :]``
        as users); the transposed copies are what the rest of the class mutates.
        """
        self.train_data = self.train_dataset.transpose(1, 0).copy()
        self.valid_data = self.valid_dataset.transpose(1, 0).copy()
        self.test_data = self.test_dataset.transpose(1, 0).copy()

    def load_eval_data(self):
        """Return the (valid, test) interaction arrays."""
        return self.valid_data, self.test_data

    def dict_user_items(self):
        """Build user->items / item->users interaction dicts and
        popularity-sorted user/item id lists."""
        self.dict_train_u_i = update_dict("user", self.train_data, defaultdict(set))
        self.dict_train_i_u = update_dict("item", self.train_data, defaultdict(set))
        # user -> all interacted items across train+valid+test
        # (used by Load_dataset.__getitem__ for negative sampling).
        tmp_dict_u_i = update_dict("user", self.valid_data, self.dict_train_u_i)
        self.user_items_dict = update_dict("user", self.test_data, tmp_dict_u_i)

        # Process out the most interacted users
        # (first sort by the number of users interacting with the user, and finally return the user values in descending order)
        sort_itme_num = sorted(self.dict_train_u_i.items(), key=lambda item: len(item[1]), reverse=True)
        self.topK_users = [temp[0] for temp in sort_itme_num]
        self.topK_users_counts = [len(temp[1]) for temp in sort_itme_num]
        # Process out the most interacted items
        # (first sort by the number of users interacting with the item, and finally return the item values in descending order)
        sort_user_num = sorted(self.dict_train_i_u.items(), key=lambda item: len(item[1]), reverse=True)
        self.topK_items = [temp[0] - self.n_users for temp in sort_user_num]  # Guaranteed from 0
        self.topK_items_counts = [len(temp[1]) for temp in sort_user_num]

    def sparse_inter_matrix(self, form):
        """Delegate building the scipy user-item interaction matrix to the helper."""
        return cal_sparse_inter_matrix(self, form)

    def log_info(self, name, interactions, list_u, list_i):
        """Log user/item counts, average actions, interaction count and sparsity.

        :param name: split label used in the log header (e.g. "Training").
        :param interactions: the interaction array (only its length is used).
        :param list_u: user-id column of the split.
        :param list_i: item-id column of the split.
        """
        info = [self.dataset_name]
        inter_num = len(interactions)
        num_u = len(set(list_u))
        num_i = len(set(list_i))
        info.extend(['The number of users: {}'.format(num_u),
                     'Average actions of users: {}'.format(inter_num / num_u)])
        info.extend(['The number of items: {}'.format(num_i),
                     'Average actions of items: {}'.format(inter_num / num_i)])
        info.append('The number of inters: {}'.format(inter_num))
        sparsity = 1 - inter_num / num_u / num_i
        info.append('The sparsity of the dataset: {}%'.format(sparsity * 100))
        self.logger.info('\n====' + name + '====\n' + str('\n'.join(info)))
+
|
| 76 |
+
class Load_dataset(BaseDataset):
    """Training dataset: builds the co-occurrence / similarity graphs once in
    ``__init__`` and serves (user, positive/negative item) pairs for training."""

    def __init__(self, config):
        super().__init__(config)
        self.item_knn_k = config.item_knn_k
        self.user_knn_k = config.user_knn_k
        self.i_mm_image_weight = config.i_mm_image_weight
        self.u_mm_image_weight = config.u_mm_image_weight
        # All item ids in the offset id space [n_users, n_users + n_items).
        self.all_set = set(range(self.n_users, self.n_users + self.n_items))
        # Print statistical information
        self.log_info("Training", self.train_data, self.train_data[:, 0], self.train_data[:, 1])

        # ***************************************************************************************
        # Prepare four graphs that will be needed later
        # (user co-occurrence graph, user interest graph, item co-occurrence graph, item semantic graph)

        # Construct a user co-occurrence matrix with several items of common interaction between all users
        self.user_co_occ_matrix = load_or_create_matrix(self.logger, "User", " co-occurrence matrix",
                                                        self.dataset_name, "user_co_occ_matrix", creat_co_occur_matrix,
                                                        "user", self.train_data, 0, self.n_users)
        # Construct an item co-occurrence matrix with several users who interact in common between all items
        self.item_co_occ_matrix = load_or_create_matrix(self.logger, "Item", " co-occurrence matrix",
                                                        self.dataset_name, "item_co_occ_matrix", creat_co_occur_matrix,
                                                        "item", self.train_data, self.n_users, self.n_items)

        # Construct a dictionary of user graphs, taking the first 200
        self.dict_user_co_occ_graph = load_or_create_matrix(self.logger, "User", " co-occurrence dict graph",
                                                            self.dataset_name, "dict_user_co_occ_graph",
                                                            creat_dict_graph,
                                                            self.user_co_occ_matrix, self.n_users)
        # Construct a dictionary of item graphs, taking the first 200
        self.dict_item_co_occ_graph = load_or_create_matrix(self.logger, "Item", " co-occurrence dict graph",
                                                            self.dataset_name, "dict_item_co_occ_graph",
                                                            creat_dict_graph,
                                                            self.item_co_occ_matrix, self.n_items)
        # ***************************************************************************************

        # Get the sparse interaction matrix of the training set
        sp_inter_m = sparse_mx_to_torch_sparse_tensor(self.sparse_inter_matrix(form='coo')).to(self.device)
        # Construct a item weight graph
        if self.i_v_feat is not None:  # 4096
            # Construct user visual interest similarity graphs
            # (each user's interest = mean visual feature of the items they interacted with).
            self.u_v_interest = tsp.mm(sp_inter_m, self.i_v_feat) / tsp.sum(sp_inter_m, [1]).unsqueeze(dim=1).to_dense()
            u_v_adj = get_knn_adj_mat(self.u_v_interest, self.user_knn_k, self.device)
            i_v_adj = get_knn_adj_mat(self.i_v_feat, self.item_knn_k, self.device)
            self.i_mm_adj = i_v_adj
            self.u_mm_adj = u_v_adj
        if self.i_t_feat is not None:  # 384
            # Construct a user text interest similarity graph
            self.u_t_interest = tsp.mm(sp_inter_m, self.i_t_feat) / tsp.sum(sp_inter_m, [1]).unsqueeze(dim=1).to_dense()
            u_t_adj = get_knn_adj_mat(self.u_t_interest, self.user_knn_k, self.device)
            i_t_adj = get_knn_adj_mat(self.i_t_feat, self.item_knn_k, self.device)
            self.i_mm_adj = i_t_adj
            self.u_mm_adj = u_t_adj
        if self.i_v_feat is not None and self.i_t_feat is not None:
            # Blend the visual and textual graphs with the configured image weights.
            self.i_mm_adj = self.i_mm_image_weight * i_v_adj + (1.0 - self.i_mm_image_weight) * i_t_adj
            self.u_mm_adj = self.u_mm_image_weight * u_v_adj + (1.0 - self.u_mm_image_weight) * u_t_adj
            # The per-modality graphs are no longer needed once blended.
            del i_t_adj, i_v_adj, u_t_adj, u_v_adj
            torch.cuda.empty_cache()

    # ***************************************************************************************
    def __len__(self):
        # Number of training interactions.
        return len(self.train_data)

    def __getitem__(self, index):
        """Return one sample: duplicated user ids and a (positive, negative) item pair."""
        user, pos_item = self.train_data[index]
        # Uniformly sample one item the user has never interacted with in any split.
        neg_item = random.sample(self.all_set - set(self.user_items_dict[user]), 1)[0]
        return torch.LongTensor([user, user]), torch.LongTensor([pos_item, neg_item])
+
|
| 145 |
+
class Load_eval_dataset(BaseDataset):
    """Batched evaluation iterator over valid/test users.

    Iterating yields ``[batch_users, batch_mask_matrix]`` where the mask matrix
    is a (2, n) LongTensor of (row-in-batch, train item id) pairs used to mask
    already-seen items out of the score matrix during full-sort evaluation.
    """

    def __init__(self, v_or_t, config, eval_dataset):
        super().__init__(config)
        self.eval_dataset = eval_dataset
        self.step = config.eval_batch_size
        self.inter_pr = 0  # Markup of the number of interactions that have been computed
        self.eval_items_per_u = []
        self.eval_len_list = []
        self.train_pos_len_list = []
        self.eval_u = list(set(eval_dataset[:, 0]))  # Total users index
        self.t_data = self.train_data
        self.pos_items_per_u = self.train_items_per_u(self.eval_u)
        self.evalute_items_per_u(self.eval_u)

        self.s_idx = 0  # eval start index s_idx=pr

        self.eval_users = len(set(eval_dataset[:, 0]))
        self.eval_items = len(set(eval_dataset[:, 1]))

        self.n_inters = eval_dataset.shape[0]  # num_interactions n_inters=pr_end
        # Print statistical information
        self.log_info(v_or_t, self.eval_dataset, eval_dataset[:, 0], eval_dataset[:, 1])

    def __len__(self):
        # Number of evaluation batches.
        return math.ceil(self.n_inters / self.step)

    def __iter__(self):
        return self

    def __next__(self):
        # NOTE(review): the cursor s_idx slices users (eval_u) but the stop
        # condition compares against the interaction count n_inters — confirm
        # this over-iteration past len(eval_u) is intended.
        if self.s_idx >= self.n_inters:
            # Reset both cursors so the loader can be iterated again next epoch.
            self.s_idx = 0
            self.inter_pr = 0
            raise StopIteration()
        return self._next_batch_data()

    def _next_batch_data(self):
        """Slice out the next user batch plus its train-interaction mask."""
        # Calculate the total number of interactions between the training set from A to B
        inter_cnt = sum(self.train_pos_len_list[self.s_idx: self.s_idx + self.step])
        batch_users = self.eval_u[self.s_idx: self.s_idx + self.step]
        batch_mask_matrix = self.pos_items_per_u[:, self.inter_pr: self.inter_pr + inter_cnt].clone()
        # user_ids to index(Always keep the index value at 0-self.step in preparation for evaluating the mask later on)
        batch_mask_matrix[0] -= self.s_idx
        self.inter_pr += inter_cnt  # Update the starting index of the fetch data interaction data
        self.s_idx += self.step  # Update the starting index of the fetching user before fetching the data interaction data

        return [batch_users, batch_mask_matrix]

    def train_items_per_u(self, eval_users):
        """Collect, per evaluation user, the train items they interacted with.

        Also appends each user's train interaction count to
        ``self.train_pos_len_list`` (side effect).

        :param eval_users: ordered list of evaluation user ids.
        :return: LongTensor of shape (2, total_train_inters) with rows
            (position of user in eval_u, item id).
        """
        u_ids, i_ids = list(), list()
        for i, u in enumerate(eval_users):
            # Search for the number of items the training set has interacted with in order
            u_ls = self.t_data[np.where(self.t_data[:, 0] == u), 1][0]
            i_len = len(u_ls)
            self.train_pos_len_list.append(i_len)
            u_ids.extend([i] * i_len)
            i_ids.extend(u_ls)
        return torch.tensor([u_ids, i_ids]).type(torch.LongTensor)

    def evalute_items_per_u(self, eval_users):
        """Record each evaluation user's ground-truth items, shifted back to
        0-based item ids (the n_users offset added in BaseDataset is removed)."""
        for u in eval_users:
            u_ls = self.eval_dataset[np.where(self.eval_dataset[:, 0] == u), 1][0]
            self.eval_len_list.append(len(u_ls))
            self.eval_items_per_u.append(u_ls - self.n_users)  # Items per user interaction
        self.eval_len_list = np.asarray(self.eval_len_list)
|
utils/evaluator.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from utils.metrics import metrics_dict
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def evaluate_model(self, epoch, eval_data, t_or_v):
    """Rank all items for every evaluation user and compute the top-k metrics.

    :param self: trainer-like object providing model, topk, metrics,
        valid_metric and (optionally) a tensorboard writer.
    :param epoch: current epoch index, used as the tensorboard global step.
    :param eval_data: Load_eval_dataset-style iterable of [users, mask] batches.
    :param t_or_v: "valid" or "test" tag prefixed to tensorboard entries.
    :return: (valid_score, metric_dict) — valid_score is the single metric
        used for early stopping (defaults to NDCG@20 when unset).
    """
    self.model.eval()
    with torch.no_grad():
        batch_matrix_list = []
        for batch_idx, batched_data in enumerate(eval_data):
            scores = self.model.full_sort_predict(batched_data)
            masked_items = batched_data[1]
            scores[masked_items[0], masked_items[1] - self.model.n_users] = -1e10  # mask out pos items,restore ori_id
            _, top_k_index = torch.topk(scores, max(self.topk), dim=-1)  # nusers x topk
            batch_matrix_list.append(top_k_index)

    pos_items = eval_data.eval_items_per_u
    pos_len_list = eval_data.eval_len_list
    top_k_index = torch.cat(batch_matrix_list, dim=0).cpu().numpy()
    assert len(pos_len_list) == len(top_k_index)
    bool_rec_matrix = []

    # For each user, mark which of the recommended items are ground-truth hits.
    for m, n in zip(pos_items, top_k_index):
        bool_rec_matrix.append([True if i in m else False for i in n])
    bool_rec_matrix = np.asarray(bool_rec_matrix)

    # get metrics
    metric_dict = {}
    result_list = cal_metrics(self.metrics, pos_len_list, bool_rec_matrix)
    list_key = []
    for metric, value in zip(self.metrics, result_list):
        for k in self.topk:
            key = '{}@{}'.format(metric, k)
            # Only the largest k of each metric is forwarded to tensorboard.
            list_key.append(key) if k == self.topk[-1] else None
            metric_dict[key] = round(value[k - 1], 4)  # Round to 4 decimal points
    valid_score = metric_dict[self.valid_metric] if self.valid_metric else metric_dict['NDCG@20']
    if self.writer is not None:
        for idx in list_key:
            self.writer.add_scalar(t_or_v + "_" + idx, metric_dict[idx], epoch)  # Precision@20,Recall@20,NDCG@20
        self.writer.add_histogram(t_or_v + '_user_visual_distribution', self.model.user_v_prefer, epoch)
        self.writer.add_histogram(t_or_v + '_user_textual_distribution', self.model.user_t_prefer, epoch)
        self.writer.add_embedding(self.model.user_id_embedding.weight, global_step=epoch,
                                  tag=t_or_v + "user_id_embedding")
        self.writer.add_embedding(self.model.item_id_embedding.weight, global_step=epoch,
                                  tag=t_or_v + "item_id_embedding")

    return valid_score, metric_dict
|
| 49 |
+
|
| 50 |
+
def cal_metrics(topk_metrics, pos_len_list, topk_index):
    """Evaluate every requested top-k metric and stack the results row-wise.

    :param topk_metrics: metric names, each a key of ``metrics_dict``.
    :param pos_len_list: per-user counts of ground-truth items.
    :param topk_index: boolean hit matrix (n_users x max_k).
    :return: np.ndarray of shape (len(topk_metrics), max_k), one row per metric.
    """
    per_metric = [metrics_dict[name](topk_index, pos_len_list) for name in topk_metrics]
    return np.stack(per_metric, axis=0)
utils/helper.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import datetime
|
| 5 |
+
import numpy as np
|
| 6 |
+
import scipy.sparse as sp
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from time import time
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from scipy.sparse import coo_matrix
|
| 13 |
+
from torch.nn.functional import cosine_similarity
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def update_dict(key_ui, dataset, edge_dict):
    """Accumulate user->items or item->users adjacency sets from edge pairs.

    :param key_ui: "user" to key on users (collecting items) or "item" to key
        on items (collecting users); any other value leaves edge_dict untouched.
    :param dataset: iterable of (user, item) edges.
    :param edge_dict: mapping of id -> set, updated in place
        (typically a ``defaultdict(set)``).
    :return: the same ``edge_dict``, for chaining.
    """
    # Plain if/elif instead of side-effecting conditional expressions:
    # identical behavior, conventional control flow.
    for user, item in dataset:
        if key_ui == "user":
            edge_dict[user].add(item)
        elif key_ui == "item":
            edge_dict[item].add(user)
    return edge_dict
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_local_time():
    """Current local time formatted like ``Mar-05-2024-13-45-07``."""
    return '{:%b-%d-%Y-%H-%M-%S}'.format(datetime.datetime.now())
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def cal_reg_loss(cal_embedding):
    """Squared L2 norm of the embedding tensor, averaged over its first dimension."""
    squared_norm = cal_embedding.norm(2) ** 2
    return squared_norm / cal_embedding.size(0)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def cal_cos_loss(user, item):
    """Mean cosine distance (1 - cosine similarity) between paired rows."""
    mean_similarity = cosine_similarity(user, item, dim=-1).mean()
    return 1 - mean_similarity
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def init_seed(seed):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cudnn autotuning for repeatable kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def early_stopping(value, best, cur_step, max_step):
    """Track the best validation score and decide whether patience ran out.

    :param value: current validation score.
    :param best: best score observed so far.
    :param cur_step: epochs elapsed since the last improvement.
    :param max_step: patience; training stops once cur_step exceeds it.
    :return: tuple ``(best, cur_step, stop_flag, update_flag)``.
    """
    if value > best:
        # Improvement: new best, patience counter resets.
        return value, 0, False, True
    # No improvement: burn one patience step and stop once it is exhausted.
    cur_step += 1
    return best, cur_step, cur_step > max_step, False
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def dict2str(result_dict):
    """Format a metric dict as a single line of ``name: 0.1234 `` pairs
    (each pair, including the last, is followed by one space)."""
    return ''.join('{}: {:.4f} '.format(metric, value)
                   for metric, value in result_dict.items())
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def res_output(epoch_idx, s_time, e_time, res, t_or_v):
    """Build the per-epoch log line for the train / valid / test phase.

    :param epoch_idx: epoch number.
    :param s_time: phase start timestamp (seconds).
    :param e_time: phase end timestamp (seconds).
    :param res: for "train", the loss tuple (total, bpr, reg, mul_vt_cl, diff);
        otherwise a metric dict passed through ``dict2str``.
    :param t_or_v: "train", "valid" or "test".
    :return: the formatted log string.
    """
    header = '\n epoch %d ' % epoch_idx + t_or_v + 'ing [time: %.2fs], ' % (e_time - s_time)
    if t_or_v == "train":
        body = ('total_loss: {:.4f}, bpr_loss: {:.4f}, reg_loss:{:.4f}, '
                'mul_vt_cl_loss: {:.4f}, diff_loss: {:.4f}').format(res[0], res[1], res[2], res[3], res[4])
    elif t_or_v == "valid":
        body = ' valid result: \n' + dict2str(res)
    else:
        body = ' test result: \n' + dict2str(res)
    return header + body
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_parameter_number(self):
    """Log the model architecture plus its total / trainable parameter counts."""
    self.logger.info(self.model)
    # Materialize once so both counts walk the same parameter list.
    params = list(self.model.parameters())
    total_num = sum(p.numel() for p in params)
    trainable_num = sum(p.numel() for p in params if p.requires_grad)
    self.logger.info('Total parameters: {}, Trainable parameters: {}'.format(total_num, trainable_num))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_norm_adj_mat(self, interaction_matrix):
    """Build the symmetrically normalised joint (user+item) adjacency matrix.

    :param self: object providing ``n_users``, ``n_items`` and ``device``.
    :param interaction_matrix: scipy COO user-item matrix (users x items).
    :return: sparse torch tensor D^{-1/2} A D^{-1/2} of size
        (n_users + n_items, n_users + n_items), where A places the interaction
        matrix in the top-right block and its transpose in the bottom-left.
    """
    adj_size = (self.n_users + self.n_items, self.n_users + self.n_items)
    A = sp.dok_matrix(adj_size, dtype=np.float32)
    inter_M = interaction_matrix
    inter_M_t = interaction_matrix.transpose()
    # {(userID,itemID):1,(userID,itemID):1}
    data_dict = dict(zip(zip(inter_M.row, inter_M.col + self.n_users), [1] * inter_M.nnz))
    data_dict.update(dict(zip(zip(inter_M_t.row + self.n_users, inter_M_t.col),
                              [1] * inter_M_t.nnz)))
    # NOTE(review): _update is a private scipy dok_matrix API used for bulk
    # assignment — may break on scipy upgrades; verify against pinned version.
    A._update(data_dict)  # Update to (n_users+n_items)*(n_users+n_items) sparse matrix
    adj = sparse_mx_to_torch_sparse_tensor(A).to(self.device)
    # The same matrix supplies both the values and the degrees for normalisation.
    return torch_sparse_tensor_norm_adj(adj, adj, adj_size, self.device)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def cal_sparse_inter_matrix(self, form='coo'):
    """Build the user-item interaction matrix as a scipy sparse matrix.

    ``self.train_dataset`` is laid out as shape (2, n_inters): row 0 holds
    user ids, row 1 holds raw (un-offset) item ids.

    :param form: 'coo' or 'csr' output format.
    :return: scipy sparse matrix of shape (n_users, n_items) with 1.0 per edge.
    :raises NotImplementedError: for any other format string.
    """
    src = self.train_dataset[0, :]
    tgt = self.train_dataset[1, :]
    # One weight of 1.0 per interaction; len(src) is the interaction count
    # (the original transposed the whole array just to measure this).
    data = np.ones(len(src))
    mat = coo_matrix((data, (src, tgt)), shape=(self.n_users, self.n_items))

    if form == 'coo':
        return mat
    elif form == 'csr':
        return mat.tocsr()
    else:
        raise NotImplementedError('sparse matrix format [{}] has not been implemented.'.format(form))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def ssl_loss(data1, data2, index, ssl_temp):
    """InfoNCE-style contrastive loss between two embedding views.

    :param data1: first view of the embedding table.
    :param data2: second view of the embedding table.
    :param index: ids selecting the rows to contrast (de-duplicated internally).
    :param ssl_temp: softmax temperature.
    :return: mean negative log-likelihood of matching row i of view 1
        with row i of view 2 against all rows of view 2.
    """
    ids = torch.unique(index)
    z1 = F.normalize(data1[ids], p=2, dim=1)
    z2 = F.normalize(data2[ids], p=2, dim=1)
    # Numerator: exp of the aligned-pair similarity.
    pos = torch.exp(torch.sum(z1 * z2, dim=1) / ssl_temp)
    # Denominator: exp similarities of each row against every row of z2.
    ttl = torch.exp(torch.mm(z1, z2.T) / ssl_temp).sum(dim=1)
    return -torch.log(pos / ttl).sum() / len(ids)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def cal_diff_loss(feat, ui_index, dim):
    """
    Squared-Frobenius "difference" (orthogonality) loss between the two
    dim-wide halves of the selected feature rows.

    :param feat: uv_ut iv_it (n_users+n_items)*dim*2
    :param ui_index: user or item index; only column 0 is used to select rows
    :param dim: width of one half; columns [:dim] and [dim:] are compared
    :return: Squared Frobenius Norm Loss
    """
    input1 = feat[ui_index[:, 0], :dim]
    input2 = feat[ui_index[:, 0], dim:]

    # Zero mean
    input1_mean = torch.mean(input1, dim=0, keepdims=True)
    input2_mean = torch.mean(input2, dim=0, keepdims=True)

    input1 = input1 - input1_mean
    input2 = input2 - input2_mean

    # Row-wise L2 normalisation; norms are detached so gradients flow only
    # through the directions, and the epsilon guards zero rows.
    input1_l2_norm = torch.norm(input1, p=2, dim=1, keepdim=True).detach()
    input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6)

    input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True).detach()
    input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6)

    # Mean squared entry of the (dim x dim) cross-correlation matrix:
    # zero when the two halves are decorrelated.
    loss = torch.mean((input1_l2.t().mm(input2_l2)).pow(2))

    return loss
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def sele_para(config):
    """Summarize the key tuned hyper-parameters as one log-friendly string."""
    banner = ("\n *****************************************************************"
              "***************************************************************** \n")
    # (label, value) pairs in logging order; labels carry their own separators
    # (note the deliberate space before the colon in s_drop / m_drop).
    fields = [
        ("l_r: ", config.learning_rate), (", reg_w: ", config.reg_weight),
        (", n_l: ", config.n_layers), (", emb_dim: ", config.embedding_dim),
        (", s_drop : ", config.s_drop), (", m_drop : ", config.m_drop),
        (", u_mm_v_w: ", config.u_mm_image_weight), (", i_mm_v_w: ", config.i_mm_image_weight),
        (", uu_co_w: ", config.uu_co_weight), (", ii_co_w: ", config.ii_co_weight),
        (", u_knn_k: ", config.user_knn_k), (", i_knn_k: ", config.item_knn_k),
        (", n_uu_layers: ", config.n_uu_layers), (", n_ii_layers: ", config.n_ii_layers),
        (", cl_temp: ", config.cl_tmp), (", rank: ", config.rank),
        (", cl_loss_w: ", config.cl_loss_weight), (", diff_loss_w: ", config.diff_loss_weight),
    ]
    return banner + ''.join(label + str(value) for label, value in fields) + "\n"
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def update_result(self, test_result):
    """Record a new best-validation test result and announce it in the log."""
    tag = '{}_{}'.format(self.model_name, self.dataset_name)
    self.logger.info(' 🏃 🏃 🏃 🏆🏆🏆 🏃 🏃 🏃 ' + tag + '--Best validation results updated!!!')
    self.best_test_upon_valid = test_result
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def stop_log(self, epoch_idx, run_start_time):
    """Log the end-of-training summary: best epoch, total runtime, key
    hyper-parameters and the test result achieved at the best validation epoch.

    :param epoch_idx: epoch at which training stopped.
    :param run_start_time: training start timestamp (seconds, from time()).
    """
    # The best epoch is the stopping epoch minus the patience steps spent since it.
    stop_output = 'Finished training, best eval result in epoch %d' % (epoch_idx - self.cur_step)
    stop_output += "\n [total time: %.2fmins], " % ((time() - run_start_time) / 60)
    stop_output += '\n ' + str(self.config.dataset) + ' key parameter: ' + sele_para(self.config)
    stop_output += 'test result: \n' + dict2str(self.best_test_upon_valid)
    self.logger.info(stop_output)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def plot_curve(self, show=True, save_path=None):
    """Plot the training-loss and validation-result curves over epochs.

    :param show: display the figure interactively.
    :param save_path: when given, also save the figure to this path.
    """
    epochs = list(self.train_loss_dict.keys())
    epochs.sort()
    train_loss_values = [float(self.train_loss_dict[epoch]) for epoch in epochs]
    # NOTE(review): indexes best_valid_result with the train-loss epochs —
    # confirm both dicts share the same keys, otherwise this raises KeyError.
    valid_result_values = [float(self.best_valid_result[epoch]) for epoch in epochs]
    plt.plot(epochs, train_loss_values, label='train', color='red')
    plt.plot(epochs, valid_result_values, label='valid', color='black')
    plt.xticks(epochs)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training loss and Validing result curves')
    if show:
        plt.show()
    if save_path:
        plt.savefig(save_path)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# Generated user or item co-occurrence matrix
|
| 206 |
+
def creat_co_occur_matrix(type_ui, all_edge, start_ui, num_ui):
|
| 207 |
+
"""
|
| 208 |
+
:param type_ui: Types of created co-occurrence graphs, {user, item}
|
| 209 |
+
:param all_edge: train data np.array([[0, 6], [0, 11], [0, 8], [1, 7]])
|
| 210 |
+
:param start_ui: Minimum or starting user or item index
|
| 211 |
+
:param num_ui:Total number of users or items
|
| 212 |
+
:return:Generated user or item co-occurrence matrix
|
| 213 |
+
"""
|
| 214 |
+
edge_dict = defaultdict(set)
|
| 215 |
+
|
| 216 |
+
for edge in all_edge:
|
| 217 |
+
user, item = edge
|
| 218 |
+
edge_dict[user].add(item) if type_ui == "user" else None
|
| 219 |
+
edge_dict[item].add(user) if type_ui == "item" else None
|
| 220 |
+
|
| 221 |
+
co_graph_matrix = torch.zeros(num_ui, num_ui)
|
| 222 |
+
key_list = sorted(list(edge_dict.keys()))
|
| 223 |
+
bar = tqdm(total=len(key_list))
|
| 224 |
+
for head in range(len(key_list)):
|
| 225 |
+
bar.update(1)
|
| 226 |
+
for rear in range(head + 1, len(key_list)):
|
| 227 |
+
head_key = key_list[head]
|
| 228 |
+
rear_key = key_list[rear]
|
| 229 |
+
ui_head = edge_dict[head_key]
|
| 230 |
+
ui_rear = edge_dict[rear_key]
|
| 231 |
+
inter_len = len(ui_head.intersection(ui_rear))
|
| 232 |
+
if inter_len > 0:
|
| 233 |
+
co_graph_matrix[head_key - start_ui][rear_key - start_ui] = inter_len
|
| 234 |
+
co_graph_matrix[rear_key - start_ui][head_key - start_ui] = inter_len
|
| 235 |
+
bar.close()
|
| 236 |
+
return co_graph_matrix
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def creat_dict_graph(co_graph_matrix, num_ui):
    """For every row, keep its (up to) 200 strongest co-occurrence neighbours.

    :param co_graph_matrix: dense (num_ui, num_ui) co-occurrence count matrix.
    :param num_ui: number of rows (users or items).
    :return: dict mapping row index -> [neighbour_indices, neighbour_counts].
    """
    dict_graph = {}
    for i in tqdm(range(num_ui)):
        # Number of rows this one actually co-occurs with, capped at 200
        # (the two original branches differed only in this top-k size).
        num_co_ui = len(torch.nonzero(co_graph_matrix[i]))
        topk_ui = torch.topk(co_graph_matrix[i], min(num_co_ui, 200))
        dict_graph[i] = [topk_ui.indices.tolist(), topk_ui.values.tolist()]
    return dict_graph
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Calculate item similarity, build similarity matrix
|
| 260 |
+
def get_knn_adj_mat(mm_embeddings, knn_k, device):
|
| 261 |
+
# Standardize and calculate similarity
|
| 262 |
+
context_norm = F.normalize(mm_embeddings, dim=1)
|
| 263 |
+
final_sim = torch.mm(context_norm, context_norm.transpose(1, 0)).cpu()
|
| 264 |
+
sim_value, knn_ind = torch.topk(final_sim, knn_k, dim=-1)
|
| 265 |
+
adj_size = final_sim.size()
|
| 266 |
+
# Construct sparse adjacency matrices
|
| 267 |
+
indices0 = torch.arange(knn_ind.shape[0])
|
| 268 |
+
indices0 = torch.unsqueeze(indices0, 1)
|
| 269 |
+
indices0 = indices0.expand(-1, knn_k)
|
| 270 |
+
indices = torch.stack((torch.flatten(indices0), torch.flatten(knn_ind)), 0)
|
| 271 |
+
sim_adj = torch.sparse.FloatTensor(indices, sim_value.flatten(), adj_size).to(device)
|
| 272 |
+
degree_adj = torch.sparse.FloatTensor(indices, torch.ones(indices.shape[1]), adj_size)
|
| 273 |
+
return torch_sparse_tensor_norm_adj(sim_adj, degree_adj, adj_size, device)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def torch_sparse_tensor_norm_adj(sim_adj, degree_adj, adj_size, device):
    """
    Symmetric (Laplacian) degree normalisation: D^-1/2 @ A @ D^-1/2.

    :param sim_adj: sparse adjacency carrying the edge weights to normalise
    :param degree_adj: sparse adjacency whose row sums give node degrees
    :param adj_size: Tensor size of adjacency matrix
    :param device: cpu or gpu
    :return: Laplace degree normalised adjacency matrix
    """
    # norm adj matrix, add epsilon to avoid Divide-by-zero warning for
    # isolated nodes.
    row_sum = 1e-7 + torch.sparse.sum(degree_adj, -1).to_dense()
    r_inv_sqrt = torch.pow(row_sum, -0.5)

    # Sparse diagonal matrix holding D^-1/2.
    col = torch.arange(adj_size[0])
    row = torch.arange(adj_size[1])
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor;
    # the size is inferred from the diagonal indices exactly as before.
    sp_degree = torch.sparse_coo_tensor(torch.stack((col, row)).to(device), r_inv_sqrt.to(device))
    return torch.spmm((torch.spmm(sp_degree, sim_adj)), sp_degree)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor.

    :param sparse_mx: any scipy sparse matrix (csr/csc/coo/...)
    :return: torch sparse COO tensor with float32 values
    """
    # Normalise to COO float32 regardless of input format. The original
    # exact-type check (type(x) != sp.coo_matrix) missed coo_array and
    # subclasses; tocoo() on an already-COO matrix is cheap.
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data).float()
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor.
    return torch.sparse_coo_tensor(indices, values, shape)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def topk_sample(n_ui, dict_graph, k, topK_ui, topK_ui_counts, aggr_mode, device):
    """Sample exactly k neighbours per node and build a sparse weight graph.

    For each node i in ``dict_graph`` (i -> [neighbour_ids, neighbour_weights]):
    - fewer than k but at least one neighbour: pad to k by re-sampling
      the node's existing neighbours at random;
    - no neighbours at all: fall back to the global ``topK_ui`` list
      (presumably globally popular users/items — confirm against caller),
      weighted by their normalised ``topK_ui_counts``;
    - k or more neighbours: keep the first k.

    :param n_ui: total number of users/items (size of output adjacency)
    :param dict_graph: dict mapping node id -> [index_list, weight_list]
    :param k: number of neighbours sampled per node
    :param topK_ui: global fallback neighbour list for isolated nodes
    :param topK_ui_counts: counts used to weight the fallback list
    :param aggr_mode: 'softmax' (weight by similarity) or 'mean' (uniform)
    :param device: cpu or gpu
    :return: (n_ui, n_ui) sparse COO tensor of sampled neighbour weights
    """
    ui_graph_index = []
    user_weight_matrix = torch.zeros(len(dict_graph), k)
    for i in range(len(dict_graph)):

        if len(dict_graph[i][0]) < k:
            if len(dict_graph[i][0]) != 0:

                ui_graph_sample = dict_graph[i][0][:k]
                ui_graph_weight = dict_graph[i][1][:k]
                # Pad to length k by re-sampling existing neighbours at random
                # (slicing above copied the lists, so dict_graph is not mutated).
                rand_index = np.random.randint(0, len(ui_graph_sample), size=k - len(ui_graph_sample))
                ui_graph_sample += np.array(ui_graph_sample)[rand_index].tolist()
                ui_graph_weight += np.array(ui_graph_weight)[rand_index].tolist()
                ui_graph_index.append(ui_graph_sample)
            else:
                # Isolated node: take the global top-k fallback list, weights
                # normalised from their occurrence counts.
                ui_graph_index.append(topK_ui[:k])
                ui_graph_weight = (np.array(topK_ui_counts[:k]) / sum(topK_ui_counts[:k])).tolist()
        else:
            ui_graph_sample = dict_graph[i][0][:k]
            ui_graph_weight = dict_graph[i][1][:k]
            ui_graph_index.append(ui_graph_sample)

        if aggr_mode == 'softmax':
            user_weight_matrix[i] = F.softmax(torch.tensor(ui_graph_weight), dim=0)  # softmax
        elif aggr_mode == 'mean':
            user_weight_matrix[i] = torch.ones(k) / k  # mean

    # Flatten (node, neighbour) pairs into COO indices: row i repeats i
    # exactly k times, paired with that node's k sampled neighbours.
    tmp_all_row = []
    tmp_all_col = []
    for i in range(n_ui):
        row = torch.zeros(1, k) + i
        tmp_all_row += row.flatten()
        tmp_all_col += ui_graph_index[i]
    tmp_all_row = torch.tensor(tmp_all_row).to(torch.int32)
    tmp_all_col = torch.tensor(tmp_all_col).to(torch.int32)
    values = user_weight_matrix.flatten().to(device)
    indices = torch.stack((tmp_all_row, tmp_all_col)).to(device)
    return torch.sparse_coo_tensor(indices, values, (n_ui, n_ui))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def load_or_create_matrix(logger, matrix_type, des, dataset_name, file_name, create_function, *create_args):
    """
    Load a matrix from file if it exists; otherwise, create and save it.

    :param logger: logger
    :param matrix_type: str, type of the matrix (e.g., 'user', 'item').
    :param des: str name of the matrix
    :param dataset_name: str, dataset name used to define the file path.
    :param file_name: str, name of the file to save or load the matrix.
    :param create_function: function, function to call for matrix creation.
    :param create_args: tuple, additional arguments for the create function.
    :return: The loaded or created matrix.
    """
    file_path = os.path.join("data", dataset_name, file_name + ".pt")
    prefix = f"{matrix_type.capitalize()} " + des

    # Cache hit: reuse the previously saved matrix.
    if os.path.exists(file_path):
        matrix = torch.load(file_path)
        logger.info(prefix + " has been loaded!")
        return matrix

    # Cache miss: build the matrix and persist it for next time.
    logger.info(prefix + " does not exist, creating!")
    matrix = create_function(*create_args)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    torch.save(matrix, file_path)
    logger.info(prefix + " has been created and saved!")
    return matrix
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def propgt_info(ego_feat, n_layers, sp_mat, last_layer=False):
    """Propagate node features over a sparse graph for n_layers hops.

    :param ego_feat: initial node feature matrix
    :param n_layers: number of propagation hops
    :param sp_mat: sparse adjacency used for message passing
    :param last_layer: if True, return only the final hop's features;
        otherwise return the mean over all hops (layer 0 included)
    :return: propagated feature matrix
    """
    layer_outputs = [ego_feat]
    current = ego_feat
    for _ in range(n_layers):
        current = torch.sparse.mm(sp_mat, current)
        layer_outputs.append(current)

    if last_layer:
        return current

    # Mean-pool the per-layer representations, including the input layer.
    return torch.stack(layer_outputs, dim=1).mean(dim=1)
|
utils/logger.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def init_logger(config):
    """
    Build a logger that shows each message on standard output and writes it
    into a log file simultaneously; the file name encodes the run's main
    hyper-parameters. All the message that you want to log MUST be str.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
    """
    LOGROOT = './log/'
    log_dir = os.path.dirname(LOGROOT)
    os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger("normal")

    logfilename = (
        "{}-{}-lr_{}-rww_{}-nl_{}-sdp_{}-mdp_{}-clt_{}-diffw_{}-semw_{}.log"
        .format(config.model_name, config.dataset, config.learning_rate,
                config.reg_weight, config.n_layers,
                config.s_drop, config.m_drop, config.cl_tmp,
                config.diff_loss_weight, config.cl_loss_weight)
    )
    logfilepath = os.path.join(LOGROOT, logfilename)

    # File output: verbose timestamp format.
    file_handler = logging.FileHandler(logfilepath, 'w', 'utf-8')
    file_handler.setFormatter(
        logging.Formatter("%(asctime)-15s %(message)s", "%a %d %b %Y %H:%M:%S"))

    # Console output: compact timestamp format.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter(u"%(asctime)-15s %(message)s", "%d %b %H:%M"))

    logger.setLevel(logging.INFO)
    # Drop handlers left over from a previous init so messages are not duplicated.
    logger.handlers = []
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def cal_recall(pos_index, pos_len):
    """Recall at each cutoff (column of pos_index), averaged over users.

    :param pos_index: (n_users, topk) 0/1 hit matrix
    :param pos_len: (n_users,) number of ground-truth positives per user
    :return: (topk,) mean recall at each cutoff
    """
    hits_at_k = np.cumsum(pos_index, axis=1)
    per_user = hits_at_k / pos_len.reshape(-1, 1)
    return per_user.mean(axis=0)
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def cal_ndcg(pos_index, pos_len):
    """NDCG at each cutoff (column of pos_index), averaged over users.

    :param pos_index: (n_users, topk) 0/1 hit matrix
    :param pos_len: (n_users,) number of ground-truth positives per user
    :return: (topk,) mean NDCG at each cutoff
    """
    n_users, topk = pos_index.shape
    len_rank = np.full_like(pos_len, topk)
    # The ideal list length cannot exceed the number of ranked positions.
    idcg_len = np.where(pos_len > len_rank, len_rank, pos_len)

    # Rank positions 1..topk, replicated per user.
    rank_positions = np.tile(np.arange(1, topk + 1), (n_users, 1)).astype(float)

    # Ideal DCG: a gain of 1 at each of the first idcg_len positions; beyond
    # that the ideal value stays flat.
    idcg = np.cumsum(1.0 / np.log2(rank_positions + 1), axis=1)
    for row, idx in enumerate(idcg_len):
        idcg[row, idx:] = idcg[row, idx - 1]

    # Actual DCG: discounted gain only at hit positions.
    gains = np.where(pos_index, 1.0 / np.log2(rank_positions + 1), 0)
    dcg = np.cumsum(gains, axis=1)

    return (dcg / idcg).mean(axis=0)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def cal_map(pos_index, pos_len):
    """Mean Average Precision at each cutoff, averaged over users.

    :param pos_index: (n_users, topk) 0/1 hit matrix
    :param pos_len: (n_users,) number of ground-truth positives per user
    :return: (topk,) mean MAP at each cutoff
    """
    topk = pos_index.shape[1]
    positions = np.arange(1, topk + 1)
    precision_at_k = pos_index.cumsum(axis=1) / positions
    # Accumulate precision only at the positions where a hit occurred.
    cum_precision = np.cumsum(precision_at_k * pos_index.astype(float), axis=1)
    len_rank = np.full_like(pos_len, topk)
    actual_len = np.where(pos_len > len_rank, len_rank, pos_len)
    result = np.zeros_like(pos_index, dtype=float)
    for row, lens in enumerate(actual_len):
        # Denominator grows with rank but is capped at the user's list length.
        denom = np.arange(1, topk + 1)
        denom[lens:] = denom[lens - 1]
        result[row] = cum_precision[row] / denom
    return result.mean(axis=0)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def cal_precision(pos_index, pos_len):
    """Precision at each cutoff (column of pos_index), averaged over users.

    ``pos_len`` is unused but kept so all metrics share one signature.

    :param pos_index: (n_users, topk) 0/1 hit matrix
    :param pos_len: (n_users,) number of ground-truth positives per user
    :return: (topk,) mean precision at each cutoff
    """
    cutoffs = np.arange(1, pos_index.shape[1] + 1)
    return (pos_index.cumsum(axis=1) / cutoffs).mean(axis=0)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
"""Function name and function mapper.
|
| 47 |
+
Useful when we have to serialize evaluation metric names
|
| 48 |
+
and call the functions based on deserialized names
|
| 49 |
+
"""
|
| 50 |
+
metrics_dict = {
|
| 51 |
+
'Precision': cal_precision,
|
| 52 |
+
'Recall': cal_recall,
|
| 53 |
+
'NDCG': cal_ndcg,
|
| 54 |
+
'MAP': cal_map,
|
| 55 |
+
}
|
utils/parser.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def parse_args():
    """Parse command-line arguments for REARM.

    :return: argparse.Namespace carrying every hyper-parameter; defaults are
        unchanged from the original version.
    """
    def _str2bool(value):
        # argparse's type=bool treats any non-empty string (even 'False') as
        # True; parse common boolean spellings explicitly instead.
        if isinstance(value, bool):
            return value
        if value.lower() in ('true', '1', 'yes', 'y'):
            return True
        if value.lower() in ('false', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

    parser = argparse.ArgumentParser(description="Run REARM.")
    parser.add_argument('--seed', type=int, default=2025, help='Seed init.')
    parser.add_argument('--model_name', default='REARM', help='Model name.')
    parser.add_argument('--use_gpu', type=_str2bool, default=True, help='enable CUDA training.')
    parser.add_argument('--gpu_id', type=int, default=0, help='The model of the device running the program')
    parser.add_argument('--dataset', nargs='?', default='baby',
                        help='Choose a dataset from {baby, sports, clothing}')
    parser.add_argument('--batch_size', type=int, default=2048, help='Batch size.')
    parser.add_argument('--eval_batch_size', type=int, default=8192, help='The data size of batch evaluation')
    # type=list would split a CLI string into single characters; nargs='+'
    # collects the values into a proper list while keeping the same default.
    parser.add_argument('--metrics', nargs='+', default=["Precision", "Recall", "NDCG"],
                        help='Choose some from {"Precision", "Recall", "NDCG", "MAP"}')
    parser.add_argument('--topk', nargs='+', type=int, default=[10, 20], help='Metrics scale')
    parser.add_argument('--embedding_dim', type=int, default=64, help='Latent dimension 64.')
    parser.add_argument('--num_epoch', type=int, default=2000, help='Epoch number.')
    parser.add_argument('--num_workers', type=int, default=8, help='Workers number.')
    parser.add_argument('--stopping_step', type=int, default=20, help='early stopping strategy.')
    parser.add_argument('--valid_metric', type=str, default="Recall@20", help='valid metric')
    parser.add_argument('--with_tensorboard', action='store_true', default=False, help='with tensorboard analysis ')

    parser.add_argument('--l_r', type=float, default=5e-5, help='Learning rate.')
    parser.add_argument('--learning_rate_scheduler', nargs=2, type=float, default=[1.0, 50],
                        help='learning rate scheduler.')
    parser.add_argument('--reg_weight', type=float, default=5e-4, help='regularization weight.')
    parser.add_argument('--num_layer', type=int, default=4, help='Layer number.')
    parser.add_argument('--s_drop', type=float, default=0.4, help='self_attention_dropout.')
    parser.add_argument('--m_drop', type=float, default=0.6, help='mutual_attention_dropout.')
    parser.add_argument('--cl_tmp', type=float, default=0.6, help='Contrast learning temperature coefficient')
    parser.add_argument('--cl_loss_weight', type=float, default=5e-6, help='contrast loss weight.')
    parser.add_argument('--diff_loss_weight', type=float, default=1e-4, help='Structure contrast loss weight.')
    parser.add_argument('--user_knn_k', type=int, default=40,
                        help='Select the 10 users most similar to the target users to build the users graph')
    parser.add_argument('--item_knn_k', type=int, default=10,
                        help='Select the 10 items most similar to the target item to build the item graph')

    parser.add_argument('--i_mm_image_weight', type=float, default=0,
                        help='The proportion of visual feat in item graph.')
    parser.add_argument('--u_mm_image_weight', type=float, default=0.2,
                        help='The proportion of visual feat in user graph.')
    parser.add_argument('--n_ii_layers', type=int, default=1,
                        help='Number of layers of item feature propagation in the item graph')
    parser.add_argument('--n_uu_layers', type=int, default=1,
                        help='Number of layers of user feature propagation in the user graph')
    parser.add_argument('--user_aggr_mode', type=str, default='softmax',
                        help='Choose a modedataset from {softmax, mean}')

    parser.add_argument('--rank', type=int, default=3, help='the dimension of low rank matrix decomposition')
    parser.add_argument('--uu_co_weight', type=float, default=0.4,
                        help='the proportion of user co-occurrence graphs to user homographs')
    parser.add_argument('--ii_co_weight', type=float, default=0.2,
                        help='the proportion of item co-occurrence graphs to user homographs')
    return parser.parse_args()
|