Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import pickle | |
| from torch_geometric.nn import GCNConv, LGConv | |
| from torch_geometric.utils import degree | |
| from torch_geometric.nn.conv import MessagePassing | |
| from torch_geometric.data import HeteroData, Data | |
| import torch_geometric.transforms as T | |
| from torch_geometric.nn import LightGCN | |
| import utils | |
# Inference runs on CPU; both files below are remapped onto this device.
device = torch.device('cpu')

# Preprocessed graph. It is loaded as a torch_geometric ``Data`` object (it is
# used via ``.num_nodes`` / ``.edge_index`` below), not plain tensors. Since
# PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True`` and refuses to
# unpickle such objects, so the pre-2.6 behaviour is requested explicitly.
# NOTE: ``weights_only=False`` unpickles arbitrary objects -- only ever point
# these loads at trusted files shipped with this app.
data = torch.load(
    "processed_MVL_light.pt",
    map_location=device,
    weights_only=False,
)

# Training checkpoint; only ``model_state_dict`` is used here.
ch = torch.load(
    './lightGCNModel_num_layers_MovieLens100K_checkpoint.pt',
    map_location=device,
    weights_only=False,
)

# The architecture must match the one the checkpoint was trained with,
# otherwise the (strict) ``load_state_dict`` call below raises.
lightGCNModel = LightGCN(
    num_nodes=data.num_nodes,
    embedding_dim=64,
    num_layers=3,
).to(device)
lightGCNModel.load_state_dict(ch['model_state_dict'])
# Inference only: leave training mode.
lightGCNModel.eval()

# Keep one direction of every edge (source id < target id). Presumably this is
# the user -> item direction because user node ids come first -- TODO confirm
# against how the graph was built. Passed to ``utils.predict`` in recommend().
mask_train = data.edge_index[0] < data.edge_index[1]
train_edge_label_index = data.edge_index[:, mask_train]

# MovieLens-100K entity counts (943 users, 1682 movies), matching the
# checkpoint name. Presumably users occupy node ids [0, num_users) and items
# the remainder -- confirm against utils.predict.
num_items = 1682
num_users = 943
def recommend(user_id):
    """Return a user's ground-truth items and their top-5 recommendations.

    Args:
        user_id: Id typed into the Gradio number box. Gradio may deliver it
            as a float (e.g. ``3.0``) or as ``None`` when the box is empty;
            it is converted to ``int`` before being passed to
            ``utils.predict``.

    Returns:
        A pair of strings, one title per line: the ``title`` column of the
        ground-truth items returned by ``utils.predict``, and the
        recommendations it returns (``k=5``).

    Raises:
        gr.Error: if no id was entered or it is not a whole number.
    """
    # float(...).is_integer() is False for 3.7, NaN and inf alike.
    if user_id is None or not float(user_id).is_integer():
        raise gr.Error("Please enter a whole-number user id.")
    user_id = int(user_id)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        ground_truth_items, recommendations = utils.predict(
            lightGCNModel, device, data, num_users, num_items, user_id,
            train_edge_label_index, k=5,
        )

    # Titles themselves contain spaces, so separate them with newlines to
    # keep item boundaries unambiguous.
    return (
        '\n'.join(ground_truth_items['title'].tolist()),
        '\n'.join(recommendations),
    )
# One number box in, two text boxes out (ground-truth items, recommendations).
# precision=0 makes Gradio hand recommend() an int rather than a float.
iface = gr.Interface(
    fn=recommend,
    inputs=gr.Number(label="User ID", precision=0),
    outputs=[
        gr.Textbox(label="Ground-truth items"),
        gr.Textbox(label="Recommended items"),
    ],
)

# Only start the web server when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()