MrShouxingMa committed on
Commit
f60c555
·
verified ·
1 Parent(s): 8c19a5e

Upload 19 files

Browse files
data/baby/test.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b890fa339aca21fac9c17c27a9f1ea163ff8df9a6e4caf353c7f7cf7d689c745
3
+ size 347040
data/baby/train.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d4c495814d61a1ea84b756f11096d468cc07113d05aecbabcef8de9cf7a1387
3
+ size 1896944
data/baby/valid.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:388bd7922d74311e6afc6fbc9f6299cc85b32310109f7d67939ef90e097babb2
3
+ size 329072
data/clothing/test.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:272ba8177a70a973842bc30eafe0d7ae7d641c5117600a0f23b24319c7ad7a61
3
+ size 659152
data/clothing/train.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:685fe396e4a8d3c6cf465a9e8421e1fa07b7d187e7fa1dc996578f11a7de896a
3
+ size 3157536
data/clothing/valid.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f8c581fa0680e5886d4781a4a551868d47138c9b2d4d1cbe886d17cef3fedcd
3
+ size 642528
data/sports/test.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9396ff247c58b918a915bc95f19574492c0cf3f8ae104b77ac2447b15f77970
3
+ size 640592
data/sports/train.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1f8c7a04efd251941083dd25bff7c154e2e2ff7bc3ccc240d45bd11f906fbf2
3
+ size 3494672
data/sports/valid.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a27b2d104aef08582fbf241dc3c77b9ce7fc0788962ec6f36eb9b8c0dafba649
3
+ size 606512
main.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import platform
4
+ from time import time
5
+ from tqdm import tqdm
6
+ from trainer import train
7
+ import torch.optim as optim
8
+ from utils.parser import parse_args
9
+ from utils.logger import init_logger
10
+ from utils.configurator import Config
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from utils.evaluator import evaluate_model
14
+ from utils.data_loader import Load_dataset, Load_eval_dataset
15
+ from utils.helper import early_stopping, plot_curve, res_output, stop_log, update_result, sele_para
16
+ from model import REARM
17
+
18
+
19
class Net:
    """End-to-end experiment driver.

    Builds the configuration, logger, datasets, model, optimizer and LR
    scheduler, then runs the train -> validate -> test loop with early
    stopping on the validation score.
    """

    def __init__(self, args):
        # Complete initialization of all parameters (including random seeds)
        self.config = Config(args)
        # Use logger
        self.logger = init_logger(self.config)
        self.logger.info(self.config)
        self.logger.info('██Server: \t' + platform.node())
        self.logger.info('██Dir: \t' + os.getcwd() + '\n')
        # Frequently-used settings lifted off the config object for brevity.
        self.device = self.config.device
        self.model_name = self.config.model_name
        self.dataset_name = self.config.dataset
        self.batch_size = self.config.batch_size
        self.num_workers = self.config.num_workers
        self.learning_rate = self.config.learning_rate
        self.num_epoch = self.config.num_epoch
        self.topk = self.config.topk
        self.metrics = self.config.metrics
        self.valid_metric = self.config.valid_metric
        self.stopping_step = self.config.stopping_step
        self.reg_weight = self.config.reg_weight
        # Early-stopping state: consecutive non-improving epochs and best
        # validation/test results seen so far.
        self.cur_step = 0
        self.best_valid_score = -1
        self.best_valid_result = {}
        self.best_test_upon_valid = {}
        # Writer will output to ./runs/ directory by default
        self.writer = SummaryWriter() if self.config.writer else None

        # Perform experimental configurations
        Dataset = Load_dataset(self.config)
        valid_dataset, test_dataset = Dataset.load_eval_data()
        self.train_data = DataLoader(Dataset, batch_size=self.batch_size, shuffle=True,
                                     num_workers=self.num_workers)

        (self.valid_data, self.test_data) = (Load_eval_dataset("Validation", self.config, valid_dataset),
                                             Load_eval_dataset("Testing", self.config, test_dataset))
        self.model = REARM(self.config, Dataset).to(self.device)
        # Weight decay doubles as L2 regularisation (see reg_loss in the model).
        self.optimizer = optim.AdamW(self.model.parameters(), self.learning_rate, weight_decay=self.reg_weight)
        # Exponential decay: multiplier = scheduler[0] ** (epoch / scheduler[1]).
        lr_scheduler = self.config.learning_rate_scheduler
        fac = lambda epoch: lr_scheduler[0] ** (epoch / lr_scheduler[1])
        scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac)
        self.lr_scheduler = scheduler
        self.logger.info(self.model)

    def plot_train_loss(self):
        """Plot the recorded training-loss curves (delegates to utils.helper.plot_curve)."""
        plot_curve(self)

    def run(self):
        """Train for up to ``num_epoch`` epochs with early stopping.

        Returns:
            The best test-set metrics observed at the best validation score,
            or a sentinel dict (``{"Recall@20": -1}`` when nothing was
            recorded yet) if the training loss became NaN.
        """
        run_start_time = time()
        for epoch_idx in tqdm(range(self.num_epoch)):
            train_start_time = time()
            train_loss = train(self, epoch_idx)
            # Save if an exception occurs: train() signals failure by
            # returning a NaN total loss in position 0.
            if torch.isnan(train_loss[0]):
                ret_value = {"Recall@20": -1} if self.best_test_upon_valid == {} else self.best_test_upon_valid
                stop_output = '\n ' + str(self.config.dataset) + ' key parameter: ' + sele_para(self.config)
                self.logger.info(stop_output)
                self.logger.info('Loss is nan at epoch: {}; last value is {}Exiting.'.format(epoch_idx, ret_value))
                return ret_value

            self.lr_scheduler.step()

            train_output = res_output(epoch_idx, train_start_time, time(), train_loss, "train")
            self.logger.info(train_output)

            # valid evaluate_model
            valid_start_time = time()
            valid_score, valid_result = evaluate_model(self, epoch_idx, self.valid_data, t_or_v="valid")

            self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                valid_score, self.best_valid_score, self.cur_step, self.stopping_step)

            # Record the best-so-far validation score per epoch (for plotting).
            self.best_valid_result[epoch_idx] = self.best_valid_score
            valid_output = res_output(epoch_idx, valid_start_time, time(), valid_result, t_or_v="valid")
            self.logger.info(valid_output)

            if update_flag:
                # Validation improved: evaluate on the test split and keep
                # these metrics as "best test upon valid".
                test_start_time = time()
                _, test_result = evaluate_model(self, epoch_idx, self.test_data, t_or_v="test")
                test_score_output = res_output(epoch_idx, test_start_time, time(), test_result, t_or_v="test")
                self.logger.info(test_score_output)
                update_result(self, test_result)

            if stop_flag:
                stop_log(self, epoch_idx, run_start_time)
                break
            else:
                print('patience ==> %d' % (self.stopping_step - self.cur_step))
        return self.best_test_upon_valid
109
+
110
+
111
if __name__ == '__main__':
    # Parse CLI arguments, build the experiment driver, and launch training.
    cli_args = parse_args()
    net = Net(cli_args)
    best_score = net.run()
model.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from utils.helper import get_norm_adj_mat, ssl_loss, topk_sample, cal_diff_loss, propgt_info
5
+
6
+
7
class REARM(nn.Module):
    """REARM multi-modal recommender.

    Fuses ID embeddings with item visual/text features and user modality
    preferences via (1) homogeneous graph propagation over strengthened
    item-item and user-user graphs, (2) self- and cross-modal attention on
    item features, and (3) heterogeneous user-item graph propagation, then
    trains with a BPR loss plus contrastive and orthogonality regularisers.
    """

    def __init__(self, config, dataset):
        super(REARM, self).__init__()

        self.n_users = dataset.n_users
        self.n_items = dataset.n_items
        self.n_nodes = self.n_users + self.n_items
        self.i_v_feat = dataset.i_v_feat
        self.i_t_feat = dataset.i_t_feat
        self.embedding_dim = config.embedding_dim
        # Feature embedding dim is tied to the ID embedding dim; several
        # slicing operations in forward() rely on the two being equal.
        self.feat_embed_dim = config.embedding_dim
        self.dim_feat = self.feat_embed_dim
        self.reg_weight = config.reg_weight
        self.device = config.device
        self.cl_tmp = config.cl_tmp                       # contrastive temperature
        self.cl_loss_weight = config.cl_loss_weight
        self.diff_loss_weight = config.diff_loss_weight
        self.n_layers = config.n_layers
        self.num_user_co = config.num_user_co
        self.num_item_co = config.num_item_co
        self.user_aggr_mode = config.user_aggr_mode
        self.n_ii_layers = config.n_ii_layers
        self.n_uu_layers = config.n_uu_layers
        self.k = config.rank                              # low-rank factor for the meta MLPs
        self.uu_co_weight = config.uu_co_weight
        self.ii_co_weight = config.ii_co_weight

        # Load user and item graphs
        self.topK_users = dataset.topK_users
        self.topK_items = dataset.topK_items
        self.dict_user_co_occ_graph = dataset.dict_user_co_occ_graph
        self.dict_item_co_occ_graph = dataset.dict_item_co_occ_graph
        self.topK_users_counts = dataset.topK_users_counts
        self.topK_items_counts = dataset.topK_items_counts

        self.s_drop = config.s_drop                       # dropout for self-attention
        self.m_drop = config.m_drop                       # dropout for mutual (cross) attention
        self.ly_norm = nn.LayerNorm(self.feat_embed_dim)

        # Single-head attentions over per-dimension "tokens" (embed size 1).
        self.self_i_attn1 = nn.MultiheadAttention(1, 1, dropout=self.s_drop, batch_first=True)
        self.self_i_attn2 = nn.MultiheadAttention(1, 1, dropout=self.s_drop, batch_first=True)

        self.mutual_i_attn1 = nn.MultiheadAttention(1, 1, dropout=self.m_drop, batch_first=True)
        self.mutual_i_attn2 = nn.MultiheadAttention(1, 1, dropout=self.m_drop, batch_first=True)

        self.user_id_embedding = nn.Embedding(self.n_users, self.embedding_dim).to(self.device)
        self.item_id_embedding = nn.Embedding(self.n_items, self.embedding_dim).to(self.device)

        self.prl = nn.PReLU().to(self.device)

        # Column vector [[1], [-1]]: matmul with (pos, neg) score pairs gives
        # pos - neg for the BPR loss.
        self.cal_bpr = torch.tensor([[1.0], [-1.0]]).to(self.device)

        # load dataset info
        self.norm_adj = get_norm_adj_mat(self, dataset.sparse_inter_matrix(form='coo')).to(self.device)
        # Process to obtain user co-occurrence matrix (n_users*num_user_co)
        self.user_co_graph = topk_sample(self.n_users, self.dict_user_co_occ_graph, self.num_user_co,
                                         self.topK_users, self.topK_users_counts, 'softmax',
                                         self.device)

        # Process to obtain item co-occurrence matrix (n_items*num_item_co)
        self.item_co_graph = topk_sample(self.n_items, self.dict_item_co_occ_graph, self.num_item_co,
                                         self.topK_items, self.topK_items_counts, 'softmax',
                                         self.device)

        # Process to obtain item similarity matrix (n_items* n_items )
        self.i_mm_adj = dataset.i_mm_adj
        # Process to obtain user similarity matrix (n_users* n_users)
        self.u_mm_adj = dataset.u_mm_adj

        # Strengthen ii and uu graphs: convex blend of co-occurrence and
        # multi-modal similarity graphs.
        self.stre_ii_graph = self.ii_co_weight * self.item_co_graph + (1.0 - self.ii_co_weight) * self.i_mm_adj
        self.stre_uu_graph = self.uu_co_weight * self.user_co_graph + (1.0 - self.uu_co_weight) * self.u_mm_adj

        if self.i_v_feat is not None:
            self.image_embedding = nn.Embedding.from_pretrained(self.i_v_feat, freeze=False).to(self.device)
            self.image_i_trs = nn.Linear(self.i_v_feat.shape[1], self.feat_embed_dim)
            # NOTE(review): Parameter(...).to(device) returns a plain tensor
            # when a device copy is made, which would silently skip parameter
            # registration; it works here only if u_v_interest is already on
            # `device` — confirm, or move the .to() inside Parameter(...).
            self.user_v_prefer = torch.nn.Parameter(dataset.u_v_interest, requires_grad=True).to(self.device)
            self.image_u_trs = nn.Linear(self.i_v_feat.shape[1], self.feat_embed_dim)

        if self.i_t_feat is not None:
            self.text_embedding = nn.Embedding.from_pretrained(self.i_t_feat, freeze=False).to(self.device)
            self.text_i_trs = nn.Linear(self.i_t_feat.shape[1], self.feat_embed_dim)
            self.user_t_prefer = torch.nn.Parameter(dataset.u_t_interest, requires_grad=True).to(self.device)
            self.text_u_trs = nn.Linear(self.i_t_feat.shape[1], self.feat_embed_dim)

        # MLP(input_dim, feature_dim, hidden_dim, output_dim)
        self.mlp_u1 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.mlp_u2 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.mlp_i1 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.mlp_i2 = MLP(self.feat_embed_dim, self.feat_embed_dim * self.k, self.feat_embed_dim * self.k, self.device)
        self.meta_netu = nn.Linear(self.feat_embed_dim * 2, self.feat_embed_dim, bias=True)  # Knowledge compression
        self.meta_neti = nn.Linear(self.feat_embed_dim * 2, self.feat_embed_dim, bias=True)  # Knowledge compression

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise embeddings (normal) and transform layers (Xavier).

        NOTE(review): unconditionally touches the image/text transforms, so
        it assumes both i_v_feat and i_t_feat are present — confirm that no
        single-modality dataset is expected.
        """
        nn.init.normal_(self.user_id_embedding.weight, std=0.1)
        nn.init.normal_(self.item_id_embedding.weight, std=0.1)

        nn.init.xavier_normal_(self.image_i_trs.weight)
        nn.init.xavier_normal_(self.text_i_trs.weight)
        nn.init.xavier_normal_(self.image_u_trs.weight)
        nn.init.xavier_normal_(self.text_u_trs.weight)

    def forward(self):
        """Build the fused node representations.

        Returns:
            Tensor of shape (n_users + n_items, 3 * embedding_dim):
            concatenation of the visual, textual and shared-knowledge views
            (users first, then items). Also caches ``self.fin_feat_prefer``
            for the contrastive/orthogonality losses.
        """
        # Uniform feature dimensions for multi-modal feature information on item
        trs_item_v_feat = self.image_i_trs(self.image_embedding.weight)
        trs_item_t_feat = self.text_i_trs(self.text_embedding.weight)  # num_items * 64

        trs_user_v_prefer = self.image_u_trs(self.user_v_prefer)
        trs_user_t_prefer = self.text_u_trs(self.user_t_prefer)  # num_users * 64

        # ====================================================================================
        # Homography Relation Learning
        # ====================================================================================
        # Item homogeneous relational learning: propagate [id | visual | text]
        # jointly over the strengthened item-item graph.
        item_v_t = torch.cat((trs_item_v_feat, trs_item_t_feat), dim=-1)
        item_id_v_t = torch.cat((self.item_id_embedding.weight, item_v_t), dim=-1)
        item_id_v_t = propgt_info(item_id_v_t, self.n_ii_layers, self.stre_ii_graph, last_layer=True)
        item_id_v_t = F.normalize(item_id_v_t)

        # Slice the propagated concatenation back into its three parts.
        item_id_ii = item_id_v_t[:, :self.embedding_dim]
        gnn_i_v_feat = item_id_v_t[:, self.feat_embed_dim:-self.feat_embed_dim]
        gnn_i_t_feat = item_id_v_t[:, -self.feat_embed_dim:]

        # User homogeneous relational learning
        user_v_t = torch.cat((trs_user_v_prefer, trs_user_t_prefer), dim=-1)
        user_id_v_t = torch.cat((self.user_id_embedding.weight, user_v_t), dim=-1)
        user_id_v_t = propgt_info(user_id_v_t, self.n_uu_layers, self.stre_uu_graph, last_layer=True)

        user_id_v_t = F.normalize(user_id_v_t)
        user_id_uu = user_id_v_t[:, :self.embedding_dim]
        gnn_u_v_prefer = user_id_v_t[:, self.embedding_dim:-self.feat_embed_dim]
        gnn_u_t_prefer = user_id_v_t[:, -self.feat_embed_dim:]

        # ====================================================================================
        # Item Feature Attention Integration
        # ====================================================================================
        # Item visual features self-attention (residual + LayerNorm + PReLU).
        item_v_feat, _ = self.self_i_attn1(gnn_i_v_feat.unsqueeze(2), gnn_i_v_feat.unsqueeze(2),
                                           gnn_i_v_feat.unsqueeze(2), need_weights=False)
        item_v_feat = self.ly_norm(gnn_i_v_feat + item_v_feat.squeeze())
        item_v_feat = self.prl(item_v_feat)

        # Item text features self-attention
        item_t_feat, _ = self.self_i_attn2(gnn_i_t_feat.unsqueeze(2), gnn_i_t_feat.unsqueeze(2),
                                           gnn_i_t_feat.unsqueeze(2), need_weights=False)
        item_t_feat = self.ly_norm(gnn_i_t_feat + item_t_feat.squeeze())
        item_t_feat = self.prl(item_t_feat)

        # ---------------------------------------------------------------------------------------
        # Item text to visual cross-attention
        i_t2v_feat, _ = self.mutual_i_attn1(item_t_feat.unsqueeze(2), item_v_feat.unsqueeze(2),
                                            item_v_feat.unsqueeze(2), need_weights=False)
        item_t2v_feat = self.ly_norm(item_v_feat + i_t2v_feat.squeeze())
        item_t2v_feat = self.prl(item_t2v_feat)

        # Item visual to text cross-attention
        i_v2t_feat, _ = self.mutual_i_attn2(item_v_feat.unsqueeze(2), item_t_feat.unsqueeze(2),
                                            item_t_feat.unsqueeze(2), need_weights=False)
        item_v2t_feat = self.ly_norm(item_t_feat.squeeze() + i_v2t_feat.squeeze())
        item_v2t_feat = self.prl(item_v2t_feat)

        user_v_prefer = self.prl(gnn_u_v_prefer)  # (num_users * 64)
        user_t_prefer = self.prl(gnn_u_t_prefer)

        # ====================================================================================
        # Heterography Relation Learning
        # ====================================================================================
        # Item feature splicing with total attentions
        item_v_t_feat = torch.cat((item_t2v_feat, item_v2t_feat), dim=-1)  # (num_items* 128)
        user_v_t_prefer = torch.cat((user_v_prefer, user_t_prefer), dim=-1)  # (num_user* 128)
        ego_feat_prefer = torch.cat((user_v_t_prefer, item_v_t_feat), dim=0)  # (num_users+num_items)* 128)
        # Cached for loss(): columns [0:d] are the visual view, [d:2d] textual.
        self.fin_feat_prefer = propgt_info(ego_feat_prefer, self.n_layers, self.norm_adj)

        ego_id_embed = torch.cat((user_id_uu, item_id_ii), dim=0)  # (num_users+num_items)* 64)
        fin_id_embed = propgt_info(ego_id_embed, self.n_layers, self.norm_adj)

        share_knowldge = self.meta_extra_share(fin_id_embed, self.fin_feat_prefer)  # (num_users+num_items)* 64)

        # Residual-add the propagated ID embedding to each view.
        fin_v = self.prl(self.fin_feat_prefer[:, :self.embedding_dim]) + fin_id_embed
        fin_t = self.prl(self.fin_feat_prefer[:, self.embedding_dim:]) + fin_id_embed
        fin_share = self.prl(share_knowldge) + fin_id_embed

        temp_full_feat_prefer = torch.cat((fin_v, fin_t), dim=-1)
        representation = torch.cat((temp_full_feat_prefer, fin_share), dim=-1)

        return representation

    def loss(self, user_tensor, item_tensor):
        """Compute the training losses for a batch.

        Args:
            user_tensor: user node ids; flattened and used to index the fused
                representation (pairs arranged so scores reshape to (-1, 2)
                as (positive, negative)).
            item_tensor: corresponding item node ids.

        Returns:
            (total_loss, bpr_loss, reg_loss, mul_vt_cl_loss, mul_diff_loss);
            reg_loss is the constant 0 because L2 is handled by AdamW.
        """
        user_tensor_flatten = user_tensor.view(-1)
        item_tensor_flatten = item_tensor.view(-1)
        out = self.forward()
        user_rep = out[user_tensor_flatten]
        item_rep = out[item_tensor_flatten]

        # Pairwise scores -> (pos, neg) columns -> BPR via [[1], [-1]] matmul.
        score = torch.sum(user_rep * item_rep, dim=1).view(-1, 2)
        bpr_score = torch.matmul(score, self.cal_bpr)
        bpr_loss = -torch.mean(nn.LogSigmoid()(bpr_score))

        # Loss of multi-modal feature contrasts (visual view vs text view).
        i_mul_vt_cl_loss = ssl_loss(self.fin_feat_prefer[:, :self.feat_embed_dim],
                                    self.fin_feat_prefer[:, -self.feat_embed_dim:], item_tensor_flatten, self.cl_tmp)
        u_mul_vt_cl_loss = ssl_loss(self.fin_feat_prefer[:, :self.feat_embed_dim],
                                    self.fin_feat_prefer[:, -self.feat_embed_dim:], user_tensor_flatten, self.cl_tmp)
        mul_vt_cl_loss = self.cl_loss_weight * (i_mul_vt_cl_loss + u_mul_vt_cl_loss)

        # Modal-unique orthogonal constraint
        # NOTE(review): the `_i_` term receives user_tensor and `_u_` receives
        # item_tensor — names and arguments look swapped; verify intent.
        mul_i_diff_loss = cal_diff_loss(self.fin_feat_prefer, user_tensor, self.feat_embed_dim)
        mul_u_diff_loss = cal_diff_loss(self.fin_feat_prefer, item_tensor, self.feat_embed_dim)
        mul_diff_loss = self.diff_loss_weight * (mul_i_diff_loss + mul_u_diff_loss)

        reg_loss = 0  # Realized in AdamW
        total_loss = bpr_loss + reg_loss + mul_vt_cl_loss + mul_diff_loss

        return total_loss, bpr_loss, reg_loss, mul_vt_cl_loss, mul_diff_loss

    def full_sort_predict(self, interaction):
        """Score all items for the users in interaction[0].

        Returns:
            (len(users), n_items) score matrix from the dot product of user
            and item representations.
        """
        user = interaction[0]
        representation = self.forward()
        u_reps, i_reps = torch.split(representation, [self.n_users, self.n_items], dim=0)
        score_mat_ui = torch.matmul(u_reps[user], i_reps.t())
        return score_mat_ui

    def meta_extra_share(self, id_embed, prefer_or_feat):
        """Extract shared knowledge via per-node low-rank transforms.

        Args:
            id_embed: (n_users+n_items, d) propagated ID embeddings.
            prefer_or_feat: (n_users+n_items, 2d) propagated modality views.

        Returns:
            (n_users+n_items, d) shared-knowledge embeddings.
        """
        u_id_embed = id_embed[:self.n_users, :]
        i_id_embed = id_embed[self.n_users:, :]

        u_v_t = prefer_or_feat[:self.n_users, :]
        i_v_t = prefer_or_feat[self.n_users:, :]

        # meta-knowlege extraction
        # NOTE(review): detach() blocks gradients into meta_netu/meta_neti
        # through this path — presumably intentional (stop-gradient), confirm.
        u_knowldge = self.meta_netu(u_v_t).detach()
        i_knowldge = self.meta_neti(i_v_t).detach()

        """ Personalized transformation parameter matrix """
        # Low rank matrix decomposition
        metau1 = self.mlp_u1(u_knowldge).reshape(-1, self.feat_embed_dim, self.k)  # N_u*d*k [19445, 64, 3]
        metau2 = self.mlp_u2(u_knowldge).reshape(-1, self.k, self.feat_embed_dim)  # N_u*k*d [19445, 3, 64]
        metai1 = self.mlp_i1(i_knowldge).reshape(-1, self.feat_embed_dim, self.k)  # N_i*d*k [7050, 64, 3]
        metai2 = self.mlp_i2(i_knowldge).reshape(-1, self.k, self.feat_embed_dim)  # N_i*k*d [7050, 3,64]
        meta_biasu = torch.mean(metau1, dim=0)  # d*k [64, 3]
        meta_biasu1 = torch.mean(metau2, dim=0)  # k*d [3,64]
        meta_biasi = torch.mean(metai1, dim=0)  # [64, 3]
        meta_biasi1 = torch.mean(metai2, dim=0)  # [3, 64]
        low_weightu1 = F.softmax(metau1 + meta_biasu, dim=1)
        low_weightu2 = F.softmax(metau2 + meta_biasu1, dim=1)
        low_weighti1 = F.softmax(metai1 + meta_biasi, dim=1)
        low_weighti2 = F.softmax(metai2 + meta_biasi1, dim=1)

        # The learned matrix as the weights of the transformed network Equal to a two-layer linear network;
        u_middle_knowldge = torch.sum(torch.multiply(u_id_embed.unsqueeze(-1), low_weightu1), dim=1)
        u_share_knowldge = torch.sum(torch.multiply(u_middle_knowldge.unsqueeze(-1), low_weightu2), dim=1)
        i_middle_knowldge = torch.sum(torch.multiply(i_id_embed.unsqueeze(-1), low_weighti1), dim=1)
        i_share_knowldge = torch.sum(torch.multiply(i_middle_knowldge.unsqueeze(-1), low_weighti2), dim=1)

        share_knowldge = torch.cat((u_share_knowldge, i_share_knowldge), dim=0)
        return share_knowldge
265
+
266
+
267
class MLP(torch.nn.Module):
    """Two-layer perceptron whose output rows are L2-normalised.

    Maps input_dim -> feature_dim (PReLU) -> output_dim, then normalises
    along the last dimension so every output vector has unit length.
    """

    def __init__(self, input_dim, feature_dim, output_dim, device):
        super(MLP, self).__init__()
        self.device = device
        # Attribute names kept stable so checkpoint state_dict keys match.
        self.linear_pre = nn.Linear(input_dim, feature_dim, bias=True)
        self.prl = nn.PReLU().to(self.device)
        self.linear_out = nn.Linear(feature_dim, output_dim, bias=True)

    def forward(self, data):
        hidden = self.prl(self.linear_pre(data))
        projected = self.linear_out(hidden)
        # Unit-normalise each row of the projection.
        return F.normalize(projected, p=2, dim=-1)
trainer.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from tqdm import tqdm
4
+ from torch.nn.utils.clip_grad import clip_grad_norm_
5
+
6
+
7
def _detached(value):
    """Drop autograd history from a loss term; pass plain Python numbers through."""
    return value.detach() if torch.is_tensor(value) else value


def train(self, epoch_idx):
    """Run one training epoch over ``self.train_data``.

    Args:
        self: the experiment driver (``Net``); must expose ``model``,
            ``optimizer``, ``train_data``, ``logger`` and ``writer``.
        epoch_idx: zero-based epoch index, used for logging only.

    Returns:
        A list of five mean per-batch losses
        ``[total, bpr, reg, mul_vt_cl, diff]`` on success, or the pair
        ``(nan_loss, tensor(0.0))`` when a NaN total loss is detected so the
        caller can abort via ``torch.isnan(result[0])``.
    """
    self.model.train()
    sum_loss = 0.0
    sum_bpr_loss, sum_reg_loss = 0.0, 0.0
    sum_diff_loss, sum_mul_vt_cl_loss = 0.0, 0.0

    step = 0.0
    for batch_idx, interactions in enumerate(self.train_data):
        self.optimizer.zero_grad()
        loss, bpr_loss, reg_loss, mul_vt_cl_loss, diff_loss = self.model.loss(interactions[0],
                                                                              interactions[1])
        if torch.isnan(loss):
            self.logger.info('Loss is nan at epoch: {}, batch index: {}. Exiting.'.format(epoch_idx, batch_idx))
            return loss, torch.tensor(0.0)

        loss.backward()
        self.optimizer.step()

        step += 1.0
        # Detach each term before accumulating: summing graph-attached
        # tensors (the original `sum_loss += loss`) keeps every batch's
        # autograd graph alive for the whole epoch and leaks memory.
        sum_loss += _detached(loss)
        sum_bpr_loss += _detached(bpr_loss)
        sum_reg_loss += _detached(reg_loss)
        sum_mul_vt_cl_loss += _detached(mul_vt_cl_loss)
        sum_diff_loss += _detached(diff_loss)
    mean_loss = sum_loss / step
    mean_bpr_loss = sum_bpr_loss / step
    mean_reg_loss = sum_reg_loss / step
    mean_mul_vt_cl_loss = sum_mul_vt_cl_loss / step
    mean_diff_loss = sum_diff_loss / step

    if self.writer is not None:
        self.writer.add_scalar('loss/train', mean_loss, epoch_idx)
        self.writer.add_scalar('loss/bpr_loss', mean_bpr_loss, epoch_idx)
        self.writer.add_scalar('loss/reg_loss', mean_reg_loss, epoch_idx)
        self.writer.add_scalar('loss/mul_vt_cl_loss', mean_mul_vt_cl_loss, epoch_idx)
        self.writer.add_scalar('loss/diff_loss', mean_diff_loss, epoch_idx)

    return [mean_loss, mean_bpr_loss, mean_reg_loss, mean_mul_vt_cl_loss, mean_diff_loss]
utils/configurator.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import multiprocessing
4
+ from utils.helper import init_seed
5
+
6
+
7
class Config(object):
    """Bag of experiment settings derived from the parsed CLI arguments.

    Copies every relevant argparse attribute onto the instance, resolves the
    compute device and DataLoader worker count, and seeds all RNGs.
    """

    # (attribute on Config, attribute on the argparse namespace), kept in the
    # original definition order because __str__ lists self.__dict__ in
    # insertion order.
    _ATTR_SOURCES = (
        ('model_name', 'model_name'),
        ('dataset', 'dataset'),
        ('learning_rate', 'l_r'),
        ('learning_rate_scheduler', 'learning_rate_scheduler'),
        ('embedding_dim', 'embedding_dim'),
        ('num_epoch', 'num_epoch'),
        ('reg_weight', 'reg_weight'),
        ('use_gpu', 'use_gpu'),
        ('gpu_id', 'gpu_id'),
        ('seed', 'seed'),
        ('batch_size', 'batch_size'),
        ('eval_batch_size', 'eval_batch_size'),
        ('topk', 'topk'),
        ('valid_metric', 'valid_metric'),
        ('metrics', 'metrics'),
        ('stopping_step', 'stopping_step'),
        ('n_layers', 'num_layer'),
        ('rank', 'rank'),
        ('s_drop', 's_drop'),
        ('m_drop', 'm_drop'),
        ('cl_tmp', 'cl_tmp'),
        ('item_knn_k', 'item_knn_k'),
        ('user_knn_k', 'user_knn_k'),
        ('num_user_co', 'user_knn_k'),   # same as user_knn_k to compute
        ('num_item_co', 'item_knn_k'),   # same as item_knn_k to compute
        ('n_ii_layers', 'n_ii_layers'),
        ('n_uu_layers', 'n_uu_layers'),
        ('writer', 'with_tensorboard'),
        ('uu_co_weight', 'uu_co_weight'),
        ('ii_co_weight', 'ii_co_weight'),
        ('cl_loss_weight', 'cl_loss_weight'),
        ('user_aggr_mode', 'user_aggr_mode'),
        ('i_mm_image_weight', 'i_mm_image_weight'),
        ('u_mm_image_weight', 'u_mm_image_weight'),
        ('diff_loss_weight', 'diff_loss_weight'),
    )

    def __init__(self, args):
        for own_name, arg_name in self._ATTR_SOURCES:
            setattr(self, own_name, getattr(args, arg_name))
        self._init_device(args)
        init_seed(self.seed)

    def _init_device(self, args):
        """Pick the torch device and cap the DataLoader worker count."""
        if self.use_gpu:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
        use_cuda = torch.cuda.is_available() and self.use_gpu
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Never spawn more workers than half the machine's cores.
        half_cores = multiprocessing.cpu_count() // 2
        self.num_workers = min(half_cores, args.num_workers)

    def __str__(self):
        listing = ',\n'.join('{} = {}'.format(name, value)
                             for name, value in self.__dict__.items())
        return '\nModel arguments: ' + listing + '.\n'
utils/data_loader.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from utils.helper import *
3
+ from logging import getLogger
4
+ from collections import defaultdict
5
+ import torch.sparse as tsp
6
+
7
+
8
class BaseDataset(object):
    """Loads interaction splits and multi-modal features for one dataset.

    After construction:
      * ``train_data``/``valid_data``/``test_data`` are (n_inter, 2) arrays of
        [user_id, item_id] pairs, with item ids offset by ``n_users`` so
        users and items share one contiguous id space;
      * ``i_v_feat``/``i_t_feat`` hold item image/text feature tensors;
      * interaction dictionaries and popularity rankings are precomputed.
    """

    def __init__(self, config):
        self.config = config
        self.logger = getLogger("normal")
        self.device = config.device
        self.dataset_name = config.dataset
        self.load_all_data()
        self.processed_eval_data()
        # Count distinct users/items across all three splits so ids that only
        # appear in valid/test are still covered.
        # NOTE(review): using the count as the offset below assumes ids are
        # contiguous 0..n-1 — confirm against the .npy files.
        self.n_users = len(set(self.train_data[:, 0]) | set(self.valid_data[:, 0]) | set(self.test_data[:, 0]))
        self.n_items = len(set(self.train_data[:, 1]) | set(self.valid_data[:, 1]) | set(self.test_data[:, 1]))
        self.train_data[:, 1] += self.n_users  # Ensure that the ids are different
        self.valid_data[:, 1] += self.n_users  # Ensure that the ids are different
        self.test_data[:, 1] += self.n_users  # Ensure that the ids are different
        self.dict_user_items()

    def load_all_data(self):
        """Load the three interaction splits and the item feature matrices from ./data/<dataset>."""
        dataset_path = str('./data/' + self.dataset_name)
        self.train_dataset = np.load(dataset_path + '/train.npy', allow_pickle=True)  # [[1,2,3],[2,3,0]]
        v_feat = np.load(dataset_path + '/image_feat.npy', allow_pickle=True)
        self.i_v_feat = torch.from_numpy(v_feat).type(torch.FloatTensor).to(self.device)  # 4096
        t_feat = np.load(dataset_path + '/text_feat.npy', allow_pickle=True)
        self.i_t_feat = torch.from_numpy(t_feat).type(torch.FloatTensor).to(self.device)  # 384
        self.valid_dataset = np.load(dataset_path + '/valid.npy', allow_pickle=True)  # [[1,2,3],[2,3,0]]
        self.test_dataset = np.load(dataset_path + '/test.npy', allow_pickle=True)  # [[1,2,3],[2,3,0]]

    def processed_eval_data(self):
        """Transpose the raw [users-row, items-row] arrays into (n_inter, 2) pair arrays."""
        self.train_data = self.train_dataset.transpose(1, 0).copy()
        self.valid_data = self.valid_dataset.transpose(1, 0).copy()
        self.test_data = self.test_dataset.transpose(1, 0).copy()

    def load_eval_data(self):
        """Return the (valid, test) interaction pair arrays."""
        return self.valid_data, self.test_data

    def dict_user_items(self):
        """Build interaction dictionaries and popularity-ranked id lists."""
        # Train-only user->items and item->users maps.
        self.dict_train_u_i = update_dict("user", self.train_data, defaultdict(set))
        self.dict_train_i_u = update_dict("item", self.train_data, defaultdict(set))
        # Full user->items map over train + valid + test.
        tmp_dict_u_i = update_dict("user", self.valid_data, self.dict_train_u_i)
        self.user_items_dict = update_dict("user", self.test_data, tmp_dict_u_i)

        # Process out the most interacted users
        # (first sort by the number of items interacting with the user, and finally return the user values in descending order)
        sort_itme_num = sorted(self.dict_train_u_i.items(), key=lambda item: len(item[1]), reverse=True)
        self.topK_users = [temp[0] for temp in sort_itme_num]
        self.topK_users_counts = [len(temp[1]) for temp in sort_itme_num]
        # Process out the most interacted items
        # (first sort by the number of users interacting with the item, and finally return the item values in descending order)
        sort_user_num = sorted(self.dict_train_i_u.items(), key=lambda item: len(item[1]), reverse=True)
        self.topK_items = [temp[0] - self.n_users for temp in sort_user_num]  # Guaranteed from 0
        self.topK_items_counts = [len(temp[1]) for temp in sort_user_num]

    def sparse_inter_matrix(self, form):
        """Return the sparse user-item interaction matrix (delegates to the helper)."""
        return cal_sparse_inter_matrix(self, form)

    def log_info(self, name, interactions, list_u, list_i):
        """Log summary statistics (counts, averages, sparsity) for one split.

        Args:
            name: split label used in the log header (e.g. "Training").
            interactions: (n_inter, 2) pair array; only its length is used.
            list_u: user-id column of the split.
            list_i: item-id column of the split.
        """
        info = [self.dataset_name]
        inter_num = len(interactions)
        num_u = len(set(list_u))
        num_i = len(set(list_i))
        info.extend(['The number of users: {}'.format(num_u),
                     'Average actions of users: {}'.format(inter_num / num_u)])
        info.extend(['The number of items: {}'.format(num_i),
                     'Average actions of items: {}'.format(inter_num / num_i)])
        info.append('The number of inters: {}'.format(inter_num))
        sparsity = 1 - inter_num / num_u / num_i
        info.append('The sparsity of the dataset: {}%'.format(sparsity * 100))
        self.logger.info('\n====' + name + '====\n' + str('\n'.join(info)))
74
+
75
+
76
class Load_dataset(BaseDataset):
    """Training dataset.

    Builds (or loads from cache) the four auxiliary graphs used by the model
    — user/item co-occurrence graphs and user-interest/item-semantic k-NN
    graphs — and yields (user, pos_item, neg_item) index pairs for BPR-style
    training. Item ids are offset by n_users throughout.
    """

    def __init__(self, config):
        super().__init__(config)
        self.item_knn_k = config.item_knn_k
        self.user_knn_k = config.user_knn_k
        self.i_mm_image_weight = config.i_mm_image_weight
        self.u_mm_image_weight = config.u_mm_image_weight
        # Universe of item ids (offset by n_users), used for negative sampling.
        self.all_set = set(range(self.n_users, self.n_users + self.n_items))
        # Print statistical information
        self.log_info("Training", self.train_data, self.train_data[:, 0], self.train_data[:, 1])

        # ***************************************************************************************
        # Prepare four graphs that will be needed later
        # (user co-occurrence graph, user interest graph, item co-occurrence graph, item semantic graph)

        # Construct a user co-occurrence matrix with several items of common interaction between all users
        self.user_co_occ_matrix = load_or_create_matrix(self.logger, "User", " co-occurrence matrix",
                                                        self.dataset_name, "user_co_occ_matrix", creat_co_occur_matrix,
                                                        "user", self.train_data, 0, self.n_users)
        # Construct an item co-occurrence matrix with several users who interact in common between all items
        self.item_co_occ_matrix = load_or_create_matrix(self.logger, "Item", " co-occurrence matrix",
                                                        self.dataset_name, "item_co_occ_matrix", creat_co_occur_matrix,
                                                        "item", self.train_data, self.n_users, self.n_items)

        # Construct a dictionary of user graphs, taking the first 200
        self.dict_user_co_occ_graph = load_or_create_matrix(self.logger, "User", " co-occurrence dict graph",
                                                            self.dataset_name, "dict_user_co_occ_graph",
                                                            creat_dict_graph,
                                                            self.user_co_occ_matrix, self.n_users)
        # Construct a dictionary of item graphs, taking the first 200
        self.dict_item_co_occ_graph = load_or_create_matrix(self.logger, "Item", " co-occurrence dict graph",
                                                            self.dataset_name, "dict_item_co_occ_graph",
                                                            creat_dict_graph,
                                                            self.item_co_occ_matrix, self.n_items)
        # ***************************************************************************************

        # Get the sparse interaction matrix of the training set
        sp_inter_m = sparse_mx_to_torch_sparse_tensor(self.sparse_inter_matrix(form='coo')).to(self.device)
        # Construct a item weight graph
        if self.i_v_feat is not None:  # 4096
            # Construct user visual interest similarity graphs
            # (mean of the visual features of each user's interacted items).
            self.u_v_interest = tsp.mm(sp_inter_m, self.i_v_feat) / tsp.sum(sp_inter_m, [1]).unsqueeze(dim=1).to_dense()
            u_v_adj = get_knn_adj_mat(self.u_v_interest, self.user_knn_k, self.device)
            i_v_adj = get_knn_adj_mat(self.i_v_feat, self.item_knn_k, self.device)
            self.i_mm_adj = i_v_adj
            self.u_mm_adj = u_v_adj
        if self.i_t_feat is not None:  # 384
            # Construct a user text interest similarity graph
            self.u_t_interest = tsp.mm(sp_inter_m, self.i_t_feat) / tsp.sum(sp_inter_m, [1]).unsqueeze(dim=1).to_dense()
            u_t_adj = get_knn_adj_mat(self.u_t_interest, self.user_knn_k, self.device)
            i_t_adj = get_knn_adj_mat(self.i_t_feat, self.item_knn_k, self.device)
            self.i_mm_adj = i_t_adj
            self.u_mm_adj = u_t_adj
        if self.i_v_feat is not None and self.i_t_feat is not None:
            # Fuse visual and textual graphs with a convex combination.
            self.i_mm_adj = self.i_mm_image_weight * i_v_adj + (1.0 - self.i_mm_image_weight) * i_t_adj
            self.u_mm_adj = self.u_mm_image_weight * u_v_adj + (1.0 - self.u_mm_image_weight) * u_t_adj
        # NOTE(review): assumes both modal features are present; `del` raises
        # NameError when one of i_v_feat / i_t_feat is None — confirm upstream.
        del i_t_adj, i_v_adj, u_t_adj, u_v_adj
        torch.cuda.empty_cache()

    # ***************************************************************************************
    def __len__(self):
        # One sample per training interaction.
        return len(self.train_data)

    def __getitem__(self, index):
        """Return ([user, user], [pos_item, neg_item]) LongTensors for BPR."""
        user, pos_item = self.train_data[index]
        # Fix: random.sample() rejects sets since Python 3.11 (deprecated in
        # 3.9, removed in 3.11); materialise the candidate negatives and draw
        # one uniformly with random.choice instead.
        neg_item = random.choice(list(self.all_set - set(self.user_items_dict[user])))
        return torch.LongTensor([user, user]), torch.LongTensor([pos_item, neg_item])
143
+
144
+
145
class Load_eval_dataset(BaseDataset):
    """Evaluation dataset/iterator.

    Iterates over evaluation users in batches of eval_batch_size, yielding
    [batch_users, batch_mask_matrix] where the mask matrix holds the
    (batch-local user index, item id) pairs seen during training, so those
    positives can be masked out before ranking.
    """

    def __init__(self, v_or_t, config, eval_dataset):
        # v_or_t: split tag ('valid'/'test') used only for logging.
        # eval_dataset: array of (user, item) pairs for this split.
        super().__init__(config)
        self.eval_dataset = eval_dataset
        self.step = config.eval_batch_size
        self.inter_pr = 0  # Markup of the number of interactions that have been computed
        self.eval_items_per_u = []
        self.eval_len_list = []
        self.train_pos_len_list = []
        self.eval_u = list(set(eval_dataset[:, 0]))  # Total users index
        self.t_data = self.train_data
        # Sparse-style (user index, item id) pairs of TRAIN positives for the
        # evaluation users; consumed window-by-window in _next_batch_data.
        self.pos_items_per_u = self.train_items_per_u(self.eval_u)
        self.evalute_items_per_u(self.eval_u)

        self.s_idx = 0  # eval start index s_idx=pr

        self.eval_users = len(set(eval_dataset[:, 0]))
        self.eval_items = len(set(eval_dataset[:, 1]))

        # NOTE(review): s_idx advances over users but is compared against the
        # number of interactions (n_inters >= len(eval_u)), so trailing
        # batches can be empty — confirm this is intended by the evaluator.
        self.n_inters = eval_dataset.shape[0]  # num_interactions n_inters=pr_end
        # Print statistical information
        self.log_info(v_or_t, self.eval_dataset, eval_dataset[:, 0], eval_dataset[:, 1])

    def __len__(self):
        # Number of batches per full pass.
        return math.ceil(self.n_inters / self.step)

    def __iter__(self):
        return self

    def __next__(self):
        # Reset both cursors at the end of a pass so the object is reusable.
        if self.s_idx >= self.n_inters:
            self.s_idx = 0
            self.inter_pr = 0
            raise StopIteration()
        return self._next_batch_data()

    def _next_batch_data(self):
        # Calculate the total number of interactions between the training set from A to B
        inter_cnt = sum(self.train_pos_len_list[self.s_idx: self.s_idx + self.step])
        batch_users = self.eval_u[self.s_idx: self.s_idx + self.step]
        batch_mask_matrix = self.pos_items_per_u[:, self.inter_pr: self.inter_pr + inter_cnt].clone()
        # user_ids to index(Always keep the index value at 0-self.step in preparation for evaluating the mask later on)
        batch_mask_matrix[0] -= self.s_idx
        self.inter_pr += inter_cnt  # Update the starting index of the fetch data interaction data
        self.s_idx += self.step  # Update the starting index of the fetching user before fetching the data interaction data

        return [batch_users, batch_mask_matrix]

    def train_items_per_u(self, eval_users):
        """Collect train-set positives of each eval user as a (2, total) LongTensor
        of (position-in-eval_u, item id) pairs; also fills train_pos_len_list."""
        u_ids, i_ids = list(), list()
        for i, u in enumerate(eval_users):
            # Search for the number of items the training set has interacted with in order
            u_ls = self.t_data[np.where(self.t_data[:, 0] == u), 1][0]
            i_len = len(u_ls)
            self.train_pos_len_list.append(i_len)
            u_ids.extend([i] * i_len)
            i_ids.extend(u_ls)
        return torch.tensor([u_ids, i_ids]).type(torch.LongTensor)

    def evalute_items_per_u(self, eval_users):
        """Record, per eval user, the ground-truth items of this split
        (shifted back to 0-based item ids) and the per-user counts."""
        for u in eval_users:
            u_ls = self.eval_dataset[np.where(self.eval_dataset[:, 0] == u), 1][0]
            self.eval_len_list.append(len(u_ls))
            self.eval_items_per_u.append(u_ls - self.n_users)  # Items per user interaction
        self.eval_len_list = np.asarray(self.eval_len_list)
utils/evaluator.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from utils.metrics import metrics_dict
4
+
5
+
6
def evaluate_model(self, epoch, eval_data, t_or_v):
    """Full-sort evaluation of the model on one split.

    Ranks all items for each batch of users, masks out items already seen in
    training, computes the configured top-k metrics, and optionally logs
    scalars/histograms/embeddings to TensorBoard.

    :param epoch: current epoch index (used as the TensorBoard step)
    :param eval_data: Load_eval_dataset-style iterable yielding
        [batch_users, batch_mask_matrix]
    :param t_or_v: split tag ('valid'/'test') prefixed to TensorBoard keys
    :return: (valid_score, metric_dict) where valid_score is the metric named
        by self.valid_metric (falls back to 'NDCG@20')
    """
    self.model.eval()
    with torch.no_grad():
        batch_matrix_list = []
        for batch_idx, batched_data in enumerate(eval_data):
            scores = self.model.full_sort_predict(batched_data)
            masked_items = batched_data[1]
            # Mask row holds (batch-local user idx, global item id); shift item
            # ids back by n_users before indexing the score matrix.
            scores[masked_items[0], masked_items[1] - self.model.n_users] = -1e10  # mask out pos items,restore ori_id
            _, top_k_index = torch.topk(scores, max(self.topk), dim=-1)  # nusers x topk
            batch_matrix_list.append(top_k_index)

        pos_items = eval_data.eval_items_per_u
        pos_len_list = eval_data.eval_len_list
        top_k_index = torch.cat(batch_matrix_list, dim=0).cpu().numpy()
        assert len(pos_len_list) == len(top_k_index)
        bool_rec_matrix = []

        # Boolean hit matrix: True where a recommended item is a ground-truth item.
        for m, n in zip(pos_items, top_k_index):
            bool_rec_matrix.append([True if i in m else False for i in n])
        bool_rec_matrix = np.asarray(bool_rec_matrix)

        # get metrics
        metric_dict = {}
        result_list = cal_metrics(self.metrics, pos_len_list, bool_rec_matrix)
        list_key = []
        for metric, value in zip(self.metrics, result_list):
            for k in self.topk:
                key = '{}@{}'.format(metric, k)
                # Only the largest cutoff's keys are kept for TensorBoard logging.
                list_key.append(key) if k == self.topk[-1] else None
                metric_dict[key] = round(value[k - 1], 4)  # Round to 4 decimal points
        valid_score = metric_dict[self.valid_metric] if self.valid_metric else metric_dict['NDCG@20']
        if self.writer is not None:
            for idx in list_key:
                self.writer.add_scalar(t_or_v + "_" + idx, metric_dict[idx], epoch)  # Precision@20,Recall@20,NDCG@20
            self.writer.add_histogram(t_or_v + '_user_visual_distribution', self.model.user_v_prefer, epoch)
            self.writer.add_histogram(t_or_v + '_user_textual_distribution', self.model.user_t_prefer, epoch)
            self.writer.add_embedding(self.model.user_id_embedding.weight, global_step=epoch,
                                      tag=t_or_v + "user_id_embedding")
            self.writer.add_embedding(self.model.item_id_embedding.weight, global_step=epoch,
                                      tag=t_or_v + "item_id_embedding")

        return valid_score, metric_dict
48
+
49
+
50
def cal_metrics(topk_metrics, pos_len_list, topk_index):
    """Evaluate every requested metric and stack their @k curves row-wise.

    :param topk_metrics: metric names, keys of metrics_dict
    :param pos_len_list: per-user counts of ground-truth items
    :param topk_index: per-user boolean hit matrix over ranks
    :return: (n_metrics, max_k) array, one row per metric
    """
    curves = [metrics_dict[name](topk_index, pos_len_list) for name in topk_metrics]
    return np.stack(curves, axis=0)
utils/helper.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import datetime
5
+ import numpy as np
6
+ import scipy.sparse as sp
7
+ import matplotlib.pyplot as plt
8
+ import torch.nn.functional as F
9
+ from tqdm import tqdm
10
+ from time import time
11
+ from collections import defaultdict
12
+ from scipy.sparse import coo_matrix
13
+ from torch.nn.functional import cosine_similarity
14
+
15
+
16
def update_dict(key_ui, dataset, edge_dict):
    """Accumulate adjacency sets from (user, item) edges.

    :param key_ui: 'user' to map user -> {items}, 'item' to map item -> {users}
    :param dataset: iterable of (user, item) pairs
    :param edge_dict: dict of sets updated in place (e.g. defaultdict(set))
    :return: the same edge_dict, for call chaining
    """
    for user, item in dataset:
        # Proper if/elif instead of conditional-expression-as-statement:
        # clearer, and key_ui values other than 'user'/'item' are a no-op
        # exactly as before.
        if key_ui == "user":
            edge_dict[user].add(item)
        elif key_ui == "item":
            edge_dict[item].add(user)
    return edge_dict
22
+
23
+
24
def get_local_time():
    """Return the current local time formatted like 'Jan-05-2025-13-45-30'."""
    now = datetime.datetime.now()
    return now.strftime('%b-%d-%Y-%H-%M-%S')
26
+
27
+
28
def cal_reg_loss(cal_embedding):
    """L2 regulariser: squared Frobenius norm divided by the batch size."""
    batch_size = cal_embedding.size()[0]
    return cal_embedding.norm(2).pow(2) / batch_size
30
+
31
+
32
def cal_cos_loss(user, item):
    """Alignment loss: 1 minus the mean cosine similarity of paired rows."""
    sim = cosine_similarity(user, item, dim=-1)
    return 1 - sim.mean()
34
+
35
+
36
def init_seed(seed):
    """Seed every RNG (Python, NumPy, torch, CUDA) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN kernel auto-tuning for deterministic behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
45
+
46
+
47
def early_stopping(value, best, cur_step, max_step):
    """One step of patience-based early stopping (higher value = better).

    :param value: current validation score
    :param best: best score seen so far
    :param cur_step: epochs since the last improvement
    :param max_step: patience budget
    :return: (best, cur_step, stop_flag, update_flag)
    """
    if value > best:
        # Improvement: reset patience and record the new best.
        return value, 0, False, True
    cur_step += 1
    # Stop once patience is exhausted.
    return best, cur_step, cur_step > max_step, False
61
+
62
+
63
def dict2str(result_dict):
    """Render a metric dict as 'name: value ' pairs with 4-decimal values.

    :param result_dict: mapping of metric name -> numeric value
    :return: single-line string; each entry is followed by a trailing space

    Uses str.join instead of repeated string concatenation (which is
    quadratic in the number of entries).
    """
    return ''.join(
        str(metric) + ': ' + '%.04f' % value + ' '
        for metric, value in result_dict.items()
    )
68
+
69
+
70
def res_output(epoch_idx, s_time, e_time, res, t_or_v):
    """Format a one-epoch report line for the train/valid/test phases.

    :param epoch_idx: epoch number
    :param s_time: phase start time (seconds)
    :param e_time: phase end time (seconds)
    :param res: loss tuple for 'train', metric dict otherwise
    :param t_or_v: 'train', 'valid' or 'test'
    :return: formatted multi-line report string
    """
    header = '\n epoch %d ' % epoch_idx + t_or_v + 'ing [time: %.2fs], ' % (e_time - s_time)
    if t_or_v == "train":
        body = 'total_loss: {:.4f}, bpr_loss: {:.4f}, reg_loss:{:.4f}, mul_vt_cl_loss: {:.4f}, diff_loss: {:.4f}'.format(
            res[0], res[1], res[2], res[3], res[4])
    elif t_or_v == "valid":
        body = ' valid result: \n' + dict2str(res)
    else:
        body = ' test result: \n' + dict2str(res)
    return header + body
80
+
81
+
82
def get_parameter_number(self):
    """Log the model architecture and its total / trainable parameter counts."""
    self.logger.info(self.model)
    # Materialise once, then count both totals from the same list.
    params = list(self.model.parameters())
    total_num = sum(p.numel() for p in params)
    trainable_num = sum(p.numel() for p in params if p.requires_grad)
    self.logger.info('Total parameters: {}, Trainable parameters: {}'.format(total_num, trainable_num))
88
+
89
+
90
def get_norm_adj_mat(self, interaction_matrix):
    """Build the symmetrically normalised bipartite adjacency over all
    users and items.

    :param interaction_matrix: scipy COO user-item interaction matrix
        (shape (n_users, n_items))
    :return: Laplacian-normalised torch sparse adjacency of size
        (n_users + n_items, n_users + n_items) on self.device
    """
    adj_size = (self.n_users + self.n_items, self.n_users + self.n_items)
    A = sp.dok_matrix(adj_size, dtype=np.float32)
    inter_M = interaction_matrix
    inter_M_t = interaction_matrix.transpose()
    # {(userID,itemID):1,(userID,itemID):1}
    # Item columns are offset by n_users so users and items share one index space.
    data_dict = dict(zip(zip(inter_M.row, inter_M.col + self.n_users), [1] * inter_M.nnz))
    data_dict.update(dict(zip(zip(inter_M_t.row + self.n_users, inter_M_t.col),
                              [1] * inter_M_t.nnz)))
    # HACK: dok_matrix._update is a private scipy API used here for fast bulk
    # assignment; it may break on newer scipy versions — verify against the
    # pinned scipy release.
    A._update(data_dict)  # Update to (n_users+n_items)*(n_users+n_items) sparse matrix
    adj = sparse_mx_to_torch_sparse_tensor(A).to(self.device)
    # Same tensor passed for values and degrees -> plain D^-1/2 A D^-1/2.
    return torch_sparse_tensor_norm_adj(adj, adj, adj_size, self.device)
102
+
103
+
104
def cal_sparse_inter_matrix(self, form='coo'):
    """Build the train-set interaction matrix as a scipy sparse matrix.

    NOTE(review): assumes self.train_dataset is laid out as 2 x n_inters
    (row 0 = user ids, row 1 = item ids) — confirm against the caller.

    :param form: 'coo' or 'csr'
    :return: sparse matrix of shape (n_users, n_items)
    """
    rows = self.train_dataset[0, :]
    cols = self.train_dataset[1, :]
    ones = np.ones(self.train_dataset.shape[1])
    mat = coo_matrix((ones, (rows, cols)), shape=(self.n_users, self.n_items))

    if form == 'coo':
        return mat
    if form == 'csr':
        return mat.tocsr()
    raise NotImplementedError('sparse matrix format [{}] has not been implemented.'.format(form))
116
+
117
+
118
def ssl_loss(data1, data2, index, ssl_temp):
    """InfoNCE-style contrastive loss between two embedding views.

    Matching rows of the two views (at the deduplicated indices) are the
    positive pairs; every row of view 2 acts as a negative for each anchor.

    :param data1: first embedding table
    :param data2: second embedding table
    :param index: row indices to contrast (duplicates are removed)
    :param ssl_temp: softmax temperature
    :return: scalar loss
    """
    uniq = torch.unique(index)
    view1 = F.normalize(data1[uniq], p=2, dim=1)
    view2 = F.normalize(data2[uniq], p=2, dim=1)
    pos_logits = (view1 * view2).sum(dim=1)
    all_logits = torch.mm(view1, view2.T)
    pos_score = torch.exp(pos_logits / ssl_temp)
    denom = torch.sum(torch.exp(all_logits / ssl_temp), dim=1)
    return -torch.sum(torch.log(pos_score / denom)) / len(uniq)
130
+
131
+
132
def cal_diff_loss(feat, ui_index, dim):
    """Squared-Frobenius-norm 'difference' (orthogonality) loss.

    Splits each selected row of `feat` into two halves of width `dim`,
    zero-centres and row-normalises both halves, then penalises their
    cross-correlation so the two sub-representations stay decorrelated.

    :param feat: uv_ut / iv_it matrix, rows concatenate two dim-wide features
    :param ui_index: index tensor; column 0 selects the rows to compare
    :param dim: width of each half
    :return: scalar loss
    """
    rows = ui_index[:, 0]
    half1 = feat[rows, :dim]
    half2 = feat[rows, dim:]

    # Zero mean per column.
    half1 = half1 - half1.mean(dim=0, keepdims=True)
    half2 = half2 - half2.mean(dim=0, keepdims=True)

    # Row-wise L2 normalisation; the norm is detached so gradients only flow
    # through directions, and 1e-6 guards against division by zero.
    norm1 = torch.norm(half1, p=2, dim=1, keepdim=True).detach()
    unit1 = half1.div(norm1.expand_as(half1) + 1e-6)
    norm2 = torch.norm(half2, p=2, dim=1, keepdim=True).detach()
    unit2 = half2.div(norm2.expand_as(half2) + 1e-6)

    # Mean squared entry of the cross-correlation matrix.
    return torch.mean((unit1.t().mm(unit2)).pow(2))
157
+
158
+
159
def sele_para(config):
    """Summarise a run's key hyper-parameters as a single banner string."""
    labelled = [
        ("l_r: ", config.learning_rate), (", reg_w: ", config.reg_weight),
        (", n_l: ", config.n_layers), (", emb_dim: ", config.embedding_dim),
        (", s_drop : ", config.s_drop), (", m_drop : ", config.m_drop),
        (", u_mm_v_w: ", config.u_mm_image_weight), (", i_mm_v_w: ", config.i_mm_image_weight),
        (", uu_co_w: ", config.uu_co_weight), (", ii_co_w: ", config.ii_co_weight),
        (", u_knn_k: ", config.user_knn_k), (", i_knn_k: ", config.item_knn_k),
        (", n_uu_layers: ", config.n_uu_layers), (", n_ii_layers: ", config.n_ii_layers),
        (", cl_temp: ", config.cl_tmp), (", rank: ", config.rank),
        (", cl_loss_w: ", config.cl_loss_weight), (", diff_loss_w: ", config.diff_loss_weight),
    ]
    banner = "\n *****************************************************************"
    banner += "***************************************************************** \n"
    return banner + ''.join(label + str(value) for label, value in labelled) + "\n"
172
+
173
+
174
def update_result(self, test_result):
    """Record a new best-on-validation test result and announce it in the log."""
    banner = ' 🏃 🏃 🏃 🏆🏆🏆 🏃 🏃 🏃 ' + self.model_name + "_" + self.dataset_name + '--Best validation results updated!!!'
    self.logger.info(banner)
    self.best_test_upon_valid = test_result
178
+
179
+
180
def stop_log(self, epoch_idx, run_start_time):
    """Log the final summary after training stops: the best epoch, total
    elapsed time, key hyper-parameters, and the kept test metrics."""
    best_epoch = epoch_idx - self.cur_step
    summary = 'Finished training, best eval result in epoch %d' % best_epoch
    summary += "\n [total time: %.2fmins], " % ((time() - run_start_time) / 60)
    summary += '\n ' + str(self.config.dataset) + ' key parameter: ' + sele_para(self.config)
    summary += 'test result: \n' + dict2str(self.best_test_upon_valid)
    self.logger.info(summary)
186
+
187
+
188
def plot_curve(self, show=True, save_path=None):
    """Plot the training-loss and validation-result curves over epochs.

    :param show: display the figure interactively
    :param save_path: optional path to also save the figure to disk

    NOTE(review): assumes self.best_valid_result can be indexed by the same
    epoch keys as self.train_loss_dict — confirm against the trainer.
    """
    epochs = list(self.train_loss_dict.keys())
    epochs.sort()
    train_loss_values = [float(self.train_loss_dict[epoch]) for epoch in epochs]
    valid_result_values = [float(self.best_valid_result[epoch]) for epoch in epochs]
    plt.plot(epochs, train_loss_values, label='train', color='red')
    plt.plot(epochs, valid_result_values, label='valid', color='black')
    plt.xticks(epochs)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training loss and Validing result curves')
    # Fix: save BEFORE show() — plt.show() consumes/clears the current
    # figure, so calling savefig afterwards can write an empty image.
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
203
+
204
+
205
# Generated user or item co-occurrence matrix
def creat_co_occur_matrix(type_ui, all_edge, start_ui, num_ui):
    """
    :param type_ui: Types of created co-occurrence graphs, {user, item}
    :param all_edge: train data np.array([[0, 6], [0, 11], [0, 8], [1, 7]])
    :param start_ui: Minimum or starting user or item index
    :param num_ui:Total number of users or items
    :return:Generated user or item co-occurrence matrix
    """
    edge_dict = defaultdict(set)

    # Adjacency sets: user -> {items} (or item -> {users}), chosen by type_ui.
    for edge in all_edge:
        user, item = edge
        edge_dict[user].add(item) if type_ui == "user" else None
        edge_dict[item].add(user) if type_ui == "item" else None

    # Dense symmetric (num_ui x num_ui) matrix: entry (a, b) is the number of
    # common neighbours of nodes a and b.
    # NOTE(review): O(n^2) pairwise loop over all node pairs — expensive for
    # large catalogues; results are cached by load_or_create_matrix upstream.
    co_graph_matrix = torch.zeros(num_ui, num_ui)
    key_list = sorted(list(edge_dict.keys()))
    bar = tqdm(total=len(key_list))
    for head in range(len(key_list)):
        bar.update(1)
        for rear in range(head + 1, len(key_list)):
            head_key = key_list[head]
            rear_key = key_list[rear]
            ui_head = edge_dict[head_key]
            ui_rear = edge_dict[rear_key]
            inter_len = len(ui_head.intersection(ui_rear))
            if inter_len > 0:
                # Shift raw ids by start_ui so indices fall into [0, num_ui).
                co_graph_matrix[head_key - start_ui][rear_key - start_ui] = inter_len
                co_graph_matrix[rear_key - start_ui][head_key - start_ui] = inter_len
    bar.close()
    return co_graph_matrix
237
+
238
+
239
def creat_dict_graph(co_graph_matrix, num_ui):
    """Turn a dense co-occurrence matrix into a per-node top-K neighbour dict.

    :param co_graph_matrix: (num_ui, num_ui) co-occurrence counts
    :param num_ui: number of nodes
    :return: {node: [[neighbour indices], [co-occurrence weights]]} keeping at
        most the 200 strongest neighbours per node

    The original duplicated the two topk branches; taking
    min(non-zero count, 200) covers both cases identically.
    """
    dict_graph = {}
    for i in tqdm(range(num_ui)):
        num_co_ui = len(torch.nonzero(co_graph_matrix[i]))
        # Keep every non-zero neighbour, capped at the 200 strongest.
        topk_ui = torch.topk(co_graph_matrix[i], min(num_co_ui, 200))
        dict_graph[i] = [topk_ui.indices.tolist(), topk_ui.values.tolist()]
    return dict_graph
257
+
258
+
259
# Calculate item similarity, build similarity matrix
def get_knn_adj_mat(mm_embeddings, knn_k, device):
    """Build a k-NN cosine-similarity graph over (multi-modal) embeddings.

    :param mm_embeddings: (n, d) feature matrix
    :param knn_k: neighbours kept per row
    :param device: device for the similarity adjacency
    :return: Laplacian-normalised sparse adjacency of size (n, n)
    """
    # Cosine similarity via dot products of L2-normalised rows.
    context_norm = F.normalize(mm_embeddings, dim=1)
    final_sim = torch.mm(context_norm, context_norm.transpose(1, 0)).cpu()
    sim_value, knn_ind = torch.topk(final_sim, knn_k, dim=-1)
    adj_size = final_sim.size()
    # Construct sparse adjacency matrices: each row index repeated knn_k
    # times, paired with that row's top-k neighbour columns.
    indices0 = torch.arange(knn_ind.shape[0])
    indices0 = torch.unsqueeze(indices0, 1)
    indices0 = indices0.expand(-1, knn_k)
    indices = torch.stack((torch.flatten(indices0), torch.flatten(knn_ind)), 0)
    # torch.sparse.FloatTensor is deprecated; use torch.sparse_coo_tensor
    # (consistent with topk_sample in this module).
    sim_adj = torch.sparse_coo_tensor(indices, sim_value.flatten(), adj_size).to(device)
    degree_adj = torch.sparse_coo_tensor(indices, torch.ones(indices.shape[1]), adj_size)
    return torch_sparse_tensor_norm_adj(sim_adj, degree_adj, adj_size, device)
274
+
275
+
276
def torch_sparse_tensor_norm_adj(sim_adj, degree_adj, adj_size, device):
    """
    :param sim_adj: Tensor adjacency matrix (The value of 0 or 1 is degree normalised; the value of [0,1] is similarity normalised)
    :param degree_adj: Tensor adjacency matrix (The value of 0 or 1 is degree normalised; the value of [0,1] is similarity normalised)
    :param adj_size: Tensor size of adjacency matrix
    :param device: cpu or gpu
    :return: Laplace degree normalised adjacency matrix (D^-1/2 * A * D^-1/2)
    """
    # norm adj matrix,add epsilon to avoid Devide by zero Warning
    row_sum = 1e-7 + torch.sparse.sum(degree_adj, -1).to_dense()
    r_inv_sqrt = torch.pow(row_sum, -0.5)

    # Sparse diagonal matrix holding D^-1/2.
    col = torch.arange(adj_size[0])
    row = torch.arange(adj_size[1])
    # torch.sparse.FloatTensor is deprecated; sparse_coo_tensor infers the
    # size from the indices just as the legacy constructor did.
    sp_degree = torch.sparse_coo_tensor(torch.stack((col, row)).to(device), r_inv_sqrt.to(device))
    return torch.spmm((torch.spmm(sp_degree, sim_adj)), sp_degree)
292
+
293
+
294
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor."""
    # isinstance instead of an exact type() comparison; any non-COO scipy
    # format is normalised to float32 COO first.
    if not isinstance(sparse_mx, sp.coo_matrix):
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data).float()
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is deprecated in favour of sparse_coo_tensor.
    return torch.sparse_coo_tensor(indices, values, shape)
303
+
304
+
305
def topk_sample(n_ui, dict_graph, k, topK_ui, topK_ui_counts, aggr_mode, device):
    """Sample exactly k neighbours per node from a top-K dict graph and pack
    them into a sparse weighted adjacency of size (n_ui, n_ui).

    :param n_ui: total number of users or items
    :param dict_graph: {node: [[neighbour ids], [weights]]} (creat_dict_graph output)
    :param k: neighbours kept per node
    :param topK_ui: global fallback neighbour list used for isolated nodes
    :param topK_ui_counts: counts matching topK_ui, used as fallback weights
    :param aggr_mode: 'softmax' (similarity-weighted edges) or 'mean' (uniform)
    :param device: device for the returned sparse tensor
    :return: torch.sparse_coo_tensor of shape (n_ui, n_ui)

    NOTE(review): assumes len(dict_graph) == n_ui and len(topK_ui) >= k —
    confirm against the callers. The padding below uses np.random, so output
    depends on the global NumPy RNG state.
    """
    ui_graph_index = []
    user_weight_matrix = torch.zeros(len(dict_graph), k)
    for i in range(len(dict_graph)):

        if len(dict_graph[i][0]) < k:
            if len(dict_graph[i][0]) != 0:

                # Fewer than k neighbours: pad by re-sampling the existing
                # ones (with replacement) until the list has length k.
                ui_graph_sample = dict_graph[i][0][:k]
                ui_graph_weight = dict_graph[i][1][:k]
                rand_index = np.random.randint(0, len(ui_graph_sample), size=k - len(ui_graph_sample))
                ui_graph_sample += np.array(ui_graph_sample)[rand_index].tolist()
                ui_graph_weight += np.array(ui_graph_weight)[rand_index].tolist()
                ui_graph_index.append(ui_graph_sample)
            else:
                # Isolated node: fall back to the globally most active nodes,
                # weighted by their normalised activity counts.
                ui_graph_index.append(topK_ui[:k])
                ui_graph_weight = (np.array(topK_ui_counts[:k]) / sum(topK_ui_counts[:k])).tolist()
        else:
            # Enough neighbours: simply truncate to the k strongest.
            ui_graph_sample = dict_graph[i][0][:k]
            ui_graph_weight = dict_graph[i][1][:k]
            ui_graph_index.append(ui_graph_sample)

        if aggr_mode == 'softmax':
            user_weight_matrix[i] = F.softmax(torch.tensor(ui_graph_weight), dim=0)  # softmax
        elif aggr_mode == 'mean':
            user_weight_matrix[i] = torch.ones(k) / k  # mean

    # Flatten the (node, neighbour) pairs into COO indices.
    tmp_all_row = []
    tmp_all_col = []
    for i in range(n_ui):
        row = torch.zeros(1, k) + i
        tmp_all_row += row.flatten()
        tmp_all_col += ui_graph_index[i]
    tmp_all_row = torch.tensor(tmp_all_row).to(torch.int32)
    tmp_all_col = torch.tensor(tmp_all_col).to(torch.int32)
    values = user_weight_matrix.flatten().to(device)
    indices = torch.stack((tmp_all_row, tmp_all_col)).to(device)
    return torch.sparse_coo_tensor(indices, values, (n_ui, n_ui))
343
+
344
+
345
def load_or_create_matrix(logger, matrix_type, des, dataset_name, file_name, create_function, *create_args):
    """
    Load a matrix from file if it exists; otherwise, create and save it.
    :param logger: logger
    :param matrix_type: str, type of the matrix (e.g., 'user', 'item').
    :param des: str name of the matrix
    :param dataset_name: str, dataset name used to define the file path.
    :param file_name: str, name of the file to save or load the matrix.
    :param create_function: function, function to call for matrix creation.
    :param create_args: tuple, additional arguments for the create function.
    :return: The loaded or created matrix.
    """
    file_path = os.path.join("data", dataset_name, file_name + ".pt")
    label = f"{matrix_type.capitalize()} " + des

    # Cache miss: build, persist, then return.
    if not os.path.exists(file_path):
        logger.info(label + " does not exist, creating!")
        matrix = create_function(*create_args)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        torch.save(matrix, file_path)
        logger.info(label + " has been created and saved!")
        return matrix

    # Cache hit: load the previously saved artefact.
    matrix = torch.load(file_path)
    logger.info(label + " has been loaded!")
    return matrix
369
+
370
+
371
def propgt_info(ego_feat, n_layers, sp_mat, last_layer=False):
    """Propagate node features through a sparse graph for n_layers hops.

    :param ego_feat: initial node feature matrix
    :param n_layers: number of propagation layers
    :param sp_mat: sparse adjacency applied at each layer
    :param last_layer: if True, return only the final layer's features;
        otherwise return the mean over all layers (including layer 0)
    """
    layer_outputs = [ego_feat]
    current = ego_feat
    for _ in range(n_layers):
        current = torch.sparse.mm(sp_mat, current)
        layer_outputs.append(current)
    if last_layer:
        return current

    stacked = torch.stack(layer_outputs, dim=1)
    return stacked.mean(dim=1, keepdim=False)
utils/logger.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+
5
def init_logger(config):
    """
    A logger that can show a message on standard output and write it into the
    file named `filename` simultaneously.
    All the message that you want to log MUST be str.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
    """
    LOGROOT = './log/'
    dir_name = os.path.dirname(LOGROOT)
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)

    logger = logging.getLogger("normal")

    # One log file per hyper-parameter configuration, so different runs do
    # not overwrite each other's history.
    name_ = "{}-{}-lr_{}-rww_{}-nl_{}-sdp_{}-mdp_{}-clt_{}-diffw_{}-semw_{}.log"
    logfilename = name_.format(config.model_name, config.dataset, config.learning_rate,
                               config.reg_weight, config.n_layers,
                               config.s_drop, config.m_drop, config.cl_tmp,
                               config.diff_loss_weight, config.cl_loss_weight)
    logfilepath = os.path.join(LOGROOT, logfilename)
    # File records use a long timestamp; console uses a short one (below).
    filefmt = "%(asctime)-15s %(message)s"
    filedatefmt = "%a %d %b %Y %H:%M:%S"

    fileformatter = logging.Formatter(filefmt, filedatefmt)

    sfmt = u"%(asctime)-15s %(message)s"
    sdatefmt = "%d %b %H:%M"
    sformatter = logging.Formatter(sfmt, sdatefmt)

    # 'w' mode: each run truncates its own log file.
    fh = logging.FileHandler(logfilepath, 'w', 'utf-8')
    fh.setFormatter(fileformatter)

    sh = logging.StreamHandler()
    sh.setFormatter(sformatter)

    logger.setLevel(logging.INFO)
    # Reset existing handlers so repeated init calls don't duplicate output.
    logger.handlers = []
    logger.addHandler(fh)
    logger.addHandler(sh)
    return logger
utils/metrics.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def cal_recall(pos_index, pos_len):
    """Recall@k curve: hits within top-k over each user's relevant-item count,
    averaged over users; returns one value per cutoff k."""
    hits_at_k = np.cumsum(pos_index, axis=1)
    per_user = hits_at_k / pos_len.reshape(-1, 1)
    return per_user.mean(axis=0)
7
+
8
+
9
def cal_ndcg(pos_index, pos_len):
    """NDCG@k curve averaged over users; one value per cutoff k.

    :param pos_index: (n_users, max_k) boolean hit matrix per rank
    :param pos_len: (n_users,) number of relevant items per user
    """
    # Ideal ranking length per user: min(pos_len, max_k).
    len_rank = np.full_like(pos_len, pos_index.shape[1])
    idcg_len = np.where(pos_len > len_rank, len_rank, pos_len)

    # IDCG: gains of a perfect ranking accumulated rank by rank...
    iranks = np.zeros_like(pos_index, dtype=float)
    iranks[:, :] = np.arange(1, pos_index.shape[1] + 1)
    idcg = np.cumsum(1.0 / np.log2(iranks + 1), axis=1)
    # ...frozen after the user's last possible relevant position.
    # NOTE(review): idx == 0 (a user with no relevant items) would freeze to
    # idcg[row, -1] — confirm upstream guarantees pos_len > 0.
    for row, idx in enumerate(idcg_len):
        idcg[row, idx:] = idcg[row, idx - 1]

    # DCG: realised gains at the ranks that are actual hits.
    ranks = np.zeros_like(pos_index, dtype=float)
    ranks[:, :] = np.arange(1, pos_index.shape[1] + 1)
    dcg = 1.0 / np.log2(ranks + 1)
    dcg = np.cumsum(np.where(pos_index, dcg, 0), axis=1)

    result = dcg / idcg
    return result.mean(axis=0)
26
+
27
+
28
def cal_map(pos_index, pos_len):
    """MAP@k curve averaged over users; one value per cutoff k.

    :param pos_index: (n_users, max_k) boolean hit matrix per rank
    :param pos_len: (n_users,) number of relevant items per user
    """
    # Precision@j at every rank j.
    pre = pos_index.cumsum(axis=1) / np.arange(1, pos_index.shape[1] + 1)
    # Running sum of precision values at the ranks that are hits.
    sum_pre = np.cumsum(pre * pos_index.astype(float), axis=1)
    len_rank = np.full_like(pos_len, pos_index.shape[1])
    actual_len = np.where(pos_len > len_rank, len_rank, pos_len)
    result = np.zeros_like(pos_index, dtype=float)
    for row, lens in enumerate(actual_len):
        # Denominator min(rank, #relevant): stops growing once every one of
        # the user's relevant items could have been retrieved.
        ranges = np.arange(1, pos_index.shape[1] + 1)
        ranges[lens:] = ranges[lens - 1]
        result[row] = sum_pre[row] / ranges
    return result.mean(axis=0)
39
+
40
+
41
def cal_precision(pos_index, pos_len):
    """Precision@k curve: hits within top-k divided by k, averaged over users.
    pos_len is unused but kept for a uniform metric signature."""
    ranks = np.arange(1, pos_index.shape[1] + 1)
    per_user = pos_index.cumsum(axis=1) / ranks
    return per_user.mean(axis=0)
44
+
45
+
46
"""Function name and function mapper.
Useful when we have to serialize evaluation metric names
and call the functions based on deserialized names
"""
# Maps the metric names used in the config/CLI ('--metrics') to their
# implementations above; consumed by utils.evaluator.cal_metrics.
metrics_dict = {
    'Precision': cal_precision,
    'Recall': cal_recall,
    'NDCG': cal_ndcg,
    'MAP': cal_map,
}
utils/parser.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
def parse_args(args=None):
    """Build and parse the REARM command-line hyper-parameters.

    Args:
        args: optional list of argument strings; when None (the default)
            argparse reads sys.argv[1:], so existing callers are unaffected.
            Passing an explicit list makes the function unit-testable.

    Returns:
        argparse.Namespace holding every run/model hyper-parameter.
    """
    def _str2bool(value):
        # BUG FIX: argparse's type=bool is a trap — bool('False') is True
        # because any non-empty string is truthy, so '--use_gpu False'
        # silently enabled the GPU. Parse the text explicitly instead.
        return str(value).lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser(description="Run REARM.")
    parser.add_argument('--seed', type=int, default=2025, help='Seed init.')
    parser.add_argument('--model_name', default='REARM', help='Model name.')
    parser.add_argument('--use_gpu', type=_str2bool, default=True, help='enable CUDA training.')
    parser.add_argument('--gpu_id', type=int, default=0, help='The model of the device running the program')
    parser.add_argument('--dataset', nargs='?', default='baby',
                        help='Choose a dataset from {baby, sports, clothing}')
    parser.add_argument('--batch_size', type=int, default=2048, help='Batch size.')
    parser.add_argument('--eval_batch_size', type=int, default=8192, help='The data size of batch evaluation')
    # BUG FIX: type=list split the argument string into single characters
    # (e.g. '--topk 10' became ['1', '0']). nargs='+' collects the
    # whitespace-separated values into a real list; defaults are unchanged.
    parser.add_argument('--metrics', nargs='+', default=["Precision", "Recall", "NDCG"],
                        help='Choose some from {"Precision", "Recall", "NDCG", "MAP"}')
    parser.add_argument('--topk', nargs='+', type=int, default=[10, 20], help='Metrics scale')
    parser.add_argument('--embedding_dim', type=int, default=64, help='Latent dimension 64.')
    parser.add_argument('--num_epoch', type=int, default=2000, help='Epoch number.')
    parser.add_argument('--num_workers', type=int, default=8, help='Workers number.')
    parser.add_argument('--stopping_step', type=int, default=20, help='early stopping strategy.')
    parser.add_argument('--valid_metric', type=str, default="Recall@20", help='valid metric')
    parser.add_argument('--with_tensorboard', action='store_true', default=False, help='with tensorboard analysis ')

    parser.add_argument('--l_r', type=float, default=5e-5, help='Learning rate.')
    parser.add_argument('--learning_rate_scheduler', nargs='+', type=float, default=[1.0, 50],
                        help='learning rate scheduler.')
    parser.add_argument('--reg_weight', type=float, default=5e-4, help='regularization weight.')
    parser.add_argument('--num_layer', type=int, default=4, help='Layer number.')
    parser.add_argument('--s_drop', type=float, default=0.4, help='self_attention_dropout.')
    parser.add_argument('--m_drop', type=float, default=0.6, help='mutual_attention_dropout.')
    parser.add_argument('--cl_tmp', type=float, default=0.6, help='Contrast learning temperature coefficient')
    parser.add_argument('--cl_loss_weight', type=float, default=5e-6, help='contrast loss weight.')
    parser.add_argument('--diff_loss_weight', type=float, default=1e-4, help='Structure contrast loss weight.')
    parser.add_argument('--user_knn_k', type=int, default=40,
                        help='Select the 10 users most similar to the target users to build the users graph')
    parser.add_argument('--item_knn_k', type=int, default=10,
                        help='Select the 10 items most similar to the target item to build the item graph')

    parser.add_argument('--i_mm_image_weight', type=float, default=0,
                        help='The proportion of visual feat in item graph.')
    parser.add_argument('--u_mm_image_weight', type=float, default=0.2,
                        help='The proportion of visual feat in user graph.')
    parser.add_argument('--n_ii_layers', type=int, default=1,
                        help='Number of layers of item feature propagation in the item graph')
    parser.add_argument('--n_uu_layers', type=int, default=1,
                        help='Number of layers of user feature propagation in the user graph')
    parser.add_argument('--user_aggr_mode', type=str, default='softmax',
                        help='Choose a mode from {softmax, mean}')

    parser.add_argument('--rank', type=int, default=3, help='the dimension of low rank matrix decomposition')
    parser.add_argument('--uu_co_weight', type=float, default=0.4,
                        help='the proportion of user co-occurrence graphs to user homographs')
    parser.add_argument('--ii_co_weight', type=float, default=0.2,
                        help='the proportion of item co-occurrence graphs to user homographs')
    return parser.parse_args(args)