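# Page residue from the capture: "Spaces: Sleeping" (appears to be a Hugging Face
# Space status banner; not part of the script).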
# Render 2-D projections (UMAP, t-SNE, PCA) of the node embeddings produced by
# the trained GNN encoder and save them under ./Visualizations.
#
# Inputs (paths relative to the working directory):
#   ./PyGdata.pt                                - saved graph with 'user' and 'movie' node types
#   PyGTrainedModelState.pt                     - state_dict for model_def.Model(hidden_channels=32)
#   ./sampled_movie_dataset/movies_metadata.csv - movie metadata with a 'title' column
#
# Several imports below are not referenced in this script; they are kept as in
# the original in case they are relied on elsewhere.
import umap.umap_ as umap
import plotly.express as px
import pandas as pd
import random
import viz_utils
import torch
import torch.nn.functional as F
from torch.nn import Linear
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from sentence_transformers import SentenceTransformer
from torch_geometric.data import HeteroData
import yaml
import os
import model_def

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.load unpickles its input, so only point it at trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which may refuse a
# pickled graph object; if loading fails there, pass weights_only=False for
# this trusted file.
data = torch.load("./PyGdata.pt", map_location=device)
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")

model = model_def.Model(hidden_channels=32).to(device)
model.load_state_dict(torch.load("PyGTrainedModelState.pt", map_location=device))
model.eval()  # inference mode (e.g. disables dropout / uses running BN stats)

total_users = data["user"].num_nodes
total_movies = data["movie"].num_nodes
print("total users =", total_users)
print("total movies =", total_movies)

# Run only the encoder; no_grad skips building an autograd graph.
with torch.no_grad():
    embeddings = model.encoder(data.x_dict, data.edge_index_dict)

# One row per node: all user rows first, then all movie rows. Each part keeps
# its own 0-based index, so index labels repeat across the two node types.
user = pd.DataFrame(embeddings["user"].detach().cpu().numpy())
movie = pd.DataFrame(embeddings["movie"].detach().cpu().numpy())
embedding_df = pd.concat([user, movie], axis=0)

# Quick sanity print: the title in positional row 20 of the metadata table.
movie_index = 20
title = movies_df.iloc[movie_index]["title"]
print(title)

os.makedirs("Visualizations", exist_ok=True)

# (label used in the log message, projection function, output path passed to
#  viz_utils.save_visualization without a file extension)
projections = (
    ("UMAP", viz_utils.visualize_embeddings_umap, "./Visualizations/umap_visualization"),
    ("TSNE", viz_utils.visualize_embeddings_tsne, "./Visualizations/tsne_visualization"),
    ("PCA", viz_utils.visualize_embeddings_pca, "./Visualizations/pca_visualization"),
)
for label, visualize, out_path in projections:
    fig = visualize(embedding_df)
    viz_utils.save_visualization(fig, out_path)
    print(f"{label} visualization saved")