# File size: 2,503 Bytes | f60c555
import os
import torch
import multiprocessing
from utils.helper import init_seed
class Config:
    """Run configuration assembled from a parsed-arguments object.

    Every attribute set in ``__init__`` is copied from an attribute of
    ``args``; ``device`` and ``num_workers`` are derived afterwards by
    ``_init_device``. ``str(config)`` lists all instance attributes in the
    order they were assigned.
    """

    def __init__(self, args):
        """Copy the settings from *args*, pick the device and seed the run.

        Args:
            args: Object exposing the attributes read below (for example an
                ``argparse.Namespace``). A missing attribute raises
                ``AttributeError``.
        """
        # NOTE: assignment order is what ``__str__`` prints, keep it stable.

        # Model / dataset selection and optimisation settings.
        self.model_name = args.model_name
        self.dataset = args.dataset
        self.learning_rate = args.l_r
        self.learning_rate_scheduler = args.learning_rate_scheduler
        self.embedding_dim = args.embedding_dim
        self.num_epoch = args.num_epoch
        self.reg_weight = args.reg_weight

        # Hardware and reproducibility.
        self.use_gpu = args.use_gpu
        self.gpu_id = args.gpu_id
        self.seed = args.seed

        # Batching and evaluation.
        self.batch_size = args.batch_size
        self.eval_batch_size = args.eval_batch_size
        self.topk = args.topk
        self.valid_metric = args.valid_metric
        self.metrics = args.metrics
        self.stopping_step = args.stopping_step

        # Model-specific hyper-parameters.
        self.n_layers = args.num_layer
        self.rank = args.rank
        self.s_drop = args.s_drop
        self.m_drop = args.m_drop
        self.cl_tmp = args.cl_tmp
        self.item_knn_k = args.item_knn_k
        self.user_knn_k = args.user_knn_k
        self.num_user_co = args.user_knn_k  # mirrors user_knn_k
        self.num_item_co = args.item_knn_k  # mirrors item_knn_k
        self.n_ii_layers = args.n_ii_layers
        self.n_uu_layers = args.n_uu_layers
        self.writer = args.with_tensorboard

        # Loss / aggregation weights.
        self.uu_co_weight = args.uu_co_weight
        self.ii_co_weight = args.ii_co_weight
        self.cl_loss_weight = args.cl_loss_weight
        self.user_aggr_mode = args.user_aggr_mode
        self.i_mm_image_weight = args.i_mm_image_weight
        self.u_mm_image_weight = args.u_mm_image_weight
        self.diff_loss_weight = args.diff_loss_weight

        self._init_device(args)
        init_seed(self.seed)

    def _init_device(self, args):
        """Set ``self.device`` and ``self.num_workers``.

        Args:
            args: Object providing ``num_workers``, the requested worker
                count.
        """
        if self.use_gpu:
            # Restrict CUDA to the configured GPU id.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)

        cuda_selected = torch.cuda.is_available() and self.use_gpu
        device_name = "cuda" if cuda_selected else "cpu"
        self.device = torch.device(device_name)

        # Ensure the requested worker count does not exceed half of the
        # CPU cores reported by the machine.
        half_of_cores = multiprocessing.cpu_count() // 2
        self.num_workers = min(args.num_workers, half_of_cores)

    def __str__(self):
        """Return all instance attributes as ``name = value`` lines."""
        rendered = [f"{name} = {value}" for name, value in vars(self).items()]
        return "\nModel arguments: " + ",\n".join(rendered) + ".\n"
|