Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import pandas as pd | |
| import torch | |
| import requests | |
| import random | |
| from io import BytesIO | |
| from PIL import Image | |
| from torch_geometric.nn import SAGEConv, to_hetero, Linear | |
| from dotenv import load_dotenv | |
| import os | |
| from IPython.display import HTML | |
| import viz_utils | |
| import model_def | |
load_dotenv()  # load environment variables from a .env file, if one exists

# Relative paths below resolve against the process working directory, which is
# wherever the app was launched from -- not necessarily this script's folder.
# Switching to the script's directory makes the data/model paths work from anywhere.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Fail fast at startup (KeyError) if the Hugging Face token is not configured.
API_KEY = os.environ["HUGGINGFACE_API_KEY"]
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"

# --- LOAD DATA AND MODEL ---
movies_df = pd.read_csv("./sampled_movie_dataset/movies_metadata.csv")  # movie metadata table
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location places every loaded tensor on `device` (GPU when available, else
# CPU), so files saved on a GPU machine still load on CPU-only hosts.
# PyGdata.pt appears to hold a pickled PyG graph object (it exposes x_dict /
# edge_index_dict), not just tensors, so it needs full unpickling; recent torch
# releases default torch.load to weights_only=True and would refuse it.
# Only ever load this file from a trusted source.
data = torch.load("./PyGdata.pt", map_location=device, weights_only=False)
model = model_def.Model(hidden_channels=32).to(device)
# A state dict is a mapping of tensors, so the safer weights-only loader suffices.
model.load_state_dict(
    torch.load("PyGTrainedModelState.pt", map_location=device, weights_only=True)
)
model.eval()  # inference mode (disables dropout etc.)
# --- STREAMLIT APP ---
st.title("Movie Recommendation App")

# --- VISUALIZATIONS ---
# Pre-rendered embedding plots. Read every file as UTF-8 so loading does not
# depend on the platform's default text encoding (e.g. cp1252 on Windows).
with open("./Visualizations/umap_visualization.html", "r", encoding='utf-8') as f:
    umap_html = f.read()
with open("./Visualizations/tsne_visualization.html", "r", encoding='utf-8') as f:
    tsne_html = f.read()
with open("./Visualizations/pca_visualization.html", "r", encoding='utf-8') as f:
    pca_html = f.read()

tab1, tab2 = st.tabs(["Visualizations", "Recommendations"])
# Run the model's encoder over the whole graph to get per-node embeddings.
# embedding_df stacks the user rows first, then the movie rows; in this file it
# is only consumed by the commented-out viz_utils plotting calls.
with torch.no_grad():
    a = model.encoder(data.x_dict, data.edge_index_dict)
user, movie = (
    pd.DataFrame(a[node_type].detach().cpu()) for node_type in ("user", "movie")
)
embedding_df = pd.concat([user, movie], axis=0)
with tab1:
    # One expander per pre-rendered HTML plot (read from disk above) instead of
    # recomputing the projections with viz_utils on every rerun.
    for viz_title, viz_html in (
        ("UMAP Visualization", umap_html),
        ("TSNE Visualization", tsne_html),
        ("PCA Visualization", pca_html),
    ):
        with st.expander(viz_title):
            st.subheader(viz_title)
            components.html(viz_html, width=800, height=800)
def get_movie_recommendations(model, data, user_id, total_movies, k=5):
    """Return the ``k`` movies the model scores highest for one user.

    Builds one (user_id, movie) candidate edge for every movie index, scores
    them all with ``model``, and keeps the highest-scoring movies.

    Args:
        model: trained link-prediction model, called as
            ``model(x_dict, edge_index_dict, edge_label_index)``.
        data: graph whose ``x_dict`` / ``edge_index_dict`` are fed to the model.
        user_id: index of the user node to score.
        total_movies: number of movie nodes to score (indices 0..total_movies-1).
        k: number of recommendations to return (default 5); clamped to
            ``total_movies`` so small catalogs do not make ``topk`` fail.

    Returns:
        Rows of the module-level ``movies_df`` at the top-``k`` positions,
        highest score first. NOTE(review): assumes movie node ``i`` corresponds
        to row ``i`` of ``movies_df`` -- confirm against how the graph was built.
    """
    k = min(k, total_movies)
    user_row = torch.full((total_movies,), user_id, dtype=torch.long, device=device)
    all_movie_ids = torch.arange(total_movies, device=device)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
    # Pure inference: skip autograd bookkeeping for the full score vector.
    with torch.no_grad():
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index).cpu()
    top_indices = pred.topk(k).indices  # sorted, largest score first
    # Plain Python ints keep pandas positional indexing independent of torch types.
    return movies_df.iloc[top_indices.tolist()]
def generate_poster(movie_title, timeout=120):
    """Generate a poster image for ``movie_title`` and show it in the app.

    Sends the title as the prompt to the Hugging Face Inference API model at
    ``API_URL`` and renders the returned image with ``st.image``. Failures are
    reported in the UI with ``st.error`` instead of being raised.

    Args:
        movie_title: prompt text; also used as the image caption.
        timeout: seconds to wait for the API before giving up (default 120),
            so a stalled request cannot hang the app indefinitely.

    Returns:
        The decoded PIL image, or ``None`` if the request or decoding failed.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    # NOTE(review): a random "parameters": {"seed": ...} entry was previously left
    # commented out here; re-add one if posters should vary between reruns.
    payload = {"inputs": movie_title}
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx responses into HTTPError
        image = Image.open(BytesIO(response.content))
        image.load()  # force decoding now so a bad payload is caught below
    except (requests.exceptions.RequestException, OSError) as err:
        # RequestException covers HTTP errors, timeouts and connection failures;
        # OSError covers a response body that PIL cannot decode as an image.
        st.error(f"Image generation failed: {err}")
        return None
    st.image(image, caption=movie_title)
    return image
with tab2:
    # User IDs are used directly as user node indices by the recommender, so
    # bound the input to the valid range instead of failing inside the model.
    num_users = data['user'].num_nodes
    user_id = st.number_input(
        "Enter the User ID:", min_value=0, max_value=max(num_users - 1, 0)
    )
    if st.button("Get Recommendations"):
        st.write("Top 5 Recommendations:")
        try:
            total_movies = data['movie'].num_nodes
            recommended_movies = get_movie_recommendations(model, data, user_id, total_movies)
            cols = st.columns(3)
            # Place posters by rank position. The DataFrame index holds the
            # movies' row labels, so using it would scatter posters unevenly.
            for rank, (_, row) in enumerate(recommended_movies.iterrows()):
                with cols[rank % 3]:
                    try:
                        generate_poster(row['title'])  # renders the image itself
                    except requests.exceptions.HTTPError as err:
                        st.error(f"Image generation failed for {row['title']}: {err}")
        except Exception as e:
            # Top-level UI boundary: surface any failure to the user.
            st.error(f"An error occurred: {e}")