erermeev-d commited on
Commit ·
0d80f56
1
Parent(s): 75a5562
Added random seed fixing
Browse files- exp/gnn/train.py +5 -1
- exp/gnn/utils.py +12 -0
exp/gnn/train.py
CHANGED
|
@@ -15,11 +15,14 @@ from exp.evaluate import evaluate_recsys
|
|
| 15 |
from exp.gnn.model import GNNModel
|
| 16 |
from exp.gnn.loss import nt_xent_loss
|
| 17 |
from exp.gnn.utils import (
|
| 18 |
-
prepare_graphs, LRSchedule,
|
| 19 |
sample_item_batch, inference_model)
|
| 20 |
|
| 21 |
|
| 22 |
def prepare_gnn_embeddings(config):
|
|
|
|
|
|
|
|
|
|
| 23 |
### Prepare graph
|
| 24 |
bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
|
| 25 |
bipartite_graph = bipartite_graph.to(config["device"])
|
|
@@ -127,6 +130,7 @@ if __name__ == "__main__":
|
|
| 127 |
parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
|
| 128 |
|
| 129 |
# Misc
|
|
|
|
| 130 |
parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
|
| 131 |
parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
|
| 132 |
parser.add_argument("--wandb_name", type=str, help="WandB run name")
|
|
|
|
| 15 |
from exp.gnn.model import GNNModel
|
| 16 |
from exp.gnn.loss import nt_xent_loss
|
| 17 |
from exp.gnn.utils import (
|
| 18 |
+
prepare_graphs, LRSchedule, fix_random,
|
| 19 |
sample_item_batch, inference_model)
|
| 20 |
|
| 21 |
|
| 22 |
def prepare_gnn_embeddings(config):
|
| 23 |
+
### Fix random seed
|
| 24 |
+
fix_random(config["seed"])
|
| 25 |
+
|
| 26 |
### Prepare graph
|
| 27 |
bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
|
| 28 |
bipartite_graph = bipartite_graph.to(config["device"])
|
|
|
|
| 130 |
parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")
|
| 131 |
|
| 132 |
# Misc
|
| 133 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
|
| 134 |
parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs.")
|
| 135 |
parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
|
| 136 |
parser.add_argument("--wandb_name", type=str, help="WandB run name")
|
exp/gnn/utils.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import dgl
|
| 3 |
import pandas as pd
|
|
@@ -7,6 +9,16 @@ from tqdm.auto import tqdm
|
|
| 7 |
from exp.utils import normalize_embeddings
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class LRSchedule:
|
| 11 |
def __init__(self, total_steps, warmup_steps, final_factor):
|
| 12 |
self._total_steps = total_steps
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import dgl
|
| 5 |
import pandas as pd
|
|
|
|
| 9 |
from exp.utils import normalize_embeddings
|
| 10 |
|
| 11 |
|
| 12 |
+
def fix_random(seed):
|
| 13 |
+
dgl.seed(seed)
|
| 14 |
+
torch.random.manual_seed(seed)
|
| 15 |
+
np.random.seed(seed)
|
| 16 |
+
random.seed(seed)
|
| 17 |
+
if torch.cuda.is_available():
|
| 18 |
+
torch.cuda.manual_seed(seed)
|
| 19 |
+
torch.cuda.manual_seed_all(seed)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
class LRSchedule:
|
| 23 |
def __init__(self, total_steps, warmup_steps, final_factor):
|
| 24 |
self._total_steps = total_steps
|