|
|
|
|
|
|
|
|
|
|
|
|
|
|
import glob
import json
import os
import subprocess

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import streamlit as st
|
|
|
|
|
|
|
|
def load_dataset():
    """
    Let the user choose a dataset source in the Streamlit UI and load it as a DataFrame.

    Three sources are offered: manual upload, Kaggle download, or a local file path.

    ``st.button`` only returns True on the rerun triggered by the click, so datasets
    loaded through the Kaggle / local-path buttons are stored in ``st.session_state``.
    This keeps them available on later reruns, for example when the user types a query.
    Uploaded files need no such handling because ``st.file_uploader`` keeps its value
    across reruns.

    Returns:
        pandas.DataFrame | None: the loaded dataset, or None if nothing is loaded yet
        or the last load attempt failed.
    """
    st.write("### Dataset Upload Options")

    upload_option = st.radio(
        "Choose how to provide the dataset:",
        ("Manual Upload", "Download from Kaggle", "Specify Local Path"),
    )

    if upload_option == "Manual Upload":
        return _load_from_upload()

    if upload_option == "Download from Kaggle":
        attempted, df = _load_from_kaggle()
    else:
        attempted, df = _load_from_local_path()

    # One slot per source, so switching the radio does not show another source's data.
    state_key = f"loaded_dataset::{upload_option}"
    if attempted:
        if df is None:
            # A failed attempt should not keep showing a previously loaded dataset.
            st.session_state.pop(state_key, None)
        else:
            st.session_state[state_key] = df
    return st.session_state.get(state_key)


def _load_from_upload():
    """Render the file uploader and return the uploaded CSV as a DataFrame (or None)."""
    st.write("#### Upload the file below:")
    uploaded_file = st.file_uploader("Upload your CSV file", type="csv")
    if uploaded_file is None:
        return None
    # Rewind defensively in case the buffer has already been consumed.
    uploaded_file.seek(0)
    df = _read_csv(uploaded_file)
    if df is not None:
        st.success("File uploaded successfully!")
    return df


def _load_from_kaggle():
    """Render the Kaggle inputs and download/read the dataset when the button is clicked.

    Returns:
        tuple[bool, pandas.DataFrame | None]: ``(attempted, df)``. ``attempted`` is True
        only on the rerun where the download button was clicked.
    """
    st.write("#### Enter your Kaggle Dataset Path and API Key")
    kaggle_dataset = st.text_input("Kaggle Dataset Path (e.g., `thedevastator/hydra-movies-dataset-directors-writers-cast-and`):")
    kaggle_api_key = st.text_area("Enter your Kaggle API Key JSON content:")

    if not st.button("Download Dataset"):
        return False, None

    kaggle_dataset = kaggle_dataset.strip()
    if not (kaggle_dataset and kaggle_api_key):
        st.warning("Please provide both the dataset path and your API key.")
        return True, None

    try:
        json.loads(kaggle_api_key)
    except json.JSONDecodeError:
        st.error("The Kaggle API key must be the JSON content of your kaggle.json file.")
        return True, None

    kaggle_dir = os.path.expanduser("~/.kaggle")
    os.makedirs(kaggle_dir, exist_ok=True)
    credentials_path = os.path.join(kaggle_dir, "kaggle.json")
    # Create the file owner-only so the key is never readable by others, even briefly.
    fd = os.open(credentials_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(kaggle_api_key)
    # os.open's mode is ignored when the file already existed, so enforce it explicitly.
    os.chmod(credentials_path, 0o600)

    slug = kaggle_dataset.split("/")[-1]
    download_dir = os.path.join("kaggle_datasets", slug)
    # Argument list (no shell): the user-supplied dataset path is never shell-parsed.
    command = ["kaggle", "datasets", "download", "-d", kaggle_dataset, "--unzip", "-p", download_dir]
    try:
        completed = subprocess.run(command, capture_output=True, text=True, check=False)
    except FileNotFoundError:
        st.error("The `kaggle` command-line tool is not installed or not on PATH.")
        return True, None
    if completed.returncode != 0:
        st.error("Failed to download dataset. Please check your inputs.")
        if completed.stderr:
            st.code(completed.stderr)
        return True, None

    # Prefer "<slug>.csv"; otherwise fall back to whatever CSV the archive contained.
    csv_path = os.path.join(download_dir, slug + ".csv")
    if not os.path.isfile(csv_path):
        candidates = sorted(glob.glob(os.path.join(download_dir, "**", "*.csv"), recursive=True))
        if not candidates:
            st.error("Failed to download dataset. Please check your inputs.")
            return True, None
        csv_path = candidates[0]

    df = _read_csv(csv_path)
    if df is not None:
        st.success(f"Dataset {os.path.basename(csv_path)} downloaded successfully!")
    return True, df


def _load_from_local_path():
    """Render the local-path input and read the CSV when the button is clicked.

    Returns:
        tuple[bool, pandas.DataFrame | None]: ``(attempted, df)``. ``attempted`` is True
        only on the rerun where the load button was clicked.
    """
    local_path = st.text_input("Specify the full local path of your CSV file:")
    if not st.button("Load Dataset"):
        return False, None

    local_path = os.path.expanduser(local_path.strip())
    # isfile (not exists) so that a directory path is reported instead of crashing read_csv.
    if not os.path.isfile(local_path):
        st.error("File not found. Please check the path and try again.")
        return True, None

    df = _read_csv(local_path)
    if df is not None:
        st.success("Dataset loaded successfully from the specified path!")
    return True, df


def _read_csv(source):
    """Read *source* with pandas, showing parse failures in the UI instead of crashing."""
    try:
        return pd.read_csv(source)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        st.error(f"Could not parse the CSV file: {err}")
        return None
|
|
|
|
|
|
|
|
def preprocess_data(df):
    """
    Normalize column names and build the ``text`` column used for embeddings.

    Column labels are stripped and lower-cased. A missing ``genres`` column is filled
    with ``"Unknown"``, and missing ``title`` / ``summary`` / ``cast`` columns are filled
    with empty strings (a warning is printed for each). Missing cell values are treated
    as empty strings, so a single NaN can no longer turn a row's ``text`` into NaN.

    Args:
        df (pandas.DataFrame): the raw dataset. It is not modified; a copy is returned.

    Returns:
        pandas.DataFrame: a copy with normalized columns and an added ``text`` column of
        the form ``"<title> <summary> <genres> <cast>"``.
    """
    df = df.copy()

    # astype(str) so that non-string column labels (e.g. ints) do not break `.str`.
    df.columns = df.columns.astype(str).str.strip().str.lower()

    if 'genres' not in df.columns:
        print("Warning: 'genres' column missing! Adding a placeholder.")
        df['genres'] = "Unknown"

    for column in ('title', 'summary', 'cast'):
        if column not in df.columns:
            print(f"Warning: '{column}' column missing! Adding an empty placeholder.")
            df[column] = ""

    # fillna + astype(str): NaN becomes "", and non-string values can be concatenated.
    title, summary, genres, cast = (
        df[column].fillna('').astype(str) for column in ('title', 'summary', 'genres', 'cast')
    )
    df['text'] = title + " " + summary + " " + genres + " " + cast

    return df
|
|
|
|
|
|
|
|
def create_faiss_index(df, model):
    """
    Encode ``df['text']`` with a sentence-transformer model and build a FAISS flat L2 index.

    Row ``i`` of the index corresponds to positional row ``i`` of ``df``; ``retrieve``
    relies on this to map search hits back with ``df.iloc``.

    Args:
        df (pandas.DataFrame): dataset with a ``text`` column (see ``preprocess_data``).
        model: object with an ``encode(list_of_str, show_progress_bar=...)`` method that
            returns a 2-D array of embeddings, e.g. a ``SentenceTransformer``.

    Returns:
        faiss.IndexFlatL2: index containing one vector per row of ``df``.

    Raises:
        ValueError: if ``df`` has no rows (there is nothing to index).
    """
    texts = df['text'].tolist()
    if not texts:
        raise ValueError("Cannot build a FAISS index from an empty dataset.")

    embeddings = model.encode(texts, show_progress_bar=True)
    # FAISS only accepts C-contiguous float32 matrices.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
|
|
|
|
|
|
|
|
def retrieve(query, model, index, df, top_k=5):
    """
    Return the ``top_k`` rows of ``df`` nearest to ``query`` in the FAISS index.

    Args:
        query (str): free-text query.
        model: object with ``encode(list_of_str)`` returning a 2-D embedding array; it
            must be the same model used to build ``index``.
        index: FAISS index whose row ``i`` corresponds to positional row ``i`` of ``df``.
        df (pandas.DataFrame): the dataset the index was built from.
        top_k (int): maximum number of results to return.

    Returns:
        list[dict]: matching rows as records, ordered by search rank. The list has
        fewer than ``top_k`` entries when the index holds fewer vectors, and is empty
        when ``top_k <= 0`` or the index is empty.
    """
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    query_embedding = np.ascontiguousarray(model.encode([query]), dtype=np.float32)
    _distances, indices = index.search(query_embedding, k)

    # FAISS pads missing neighbours with -1; df.iloc[-1] would silently return the last
    # row as a bogus hit, so drop those labels.
    hits = [int(i) for i in indices[0] if i >= 0]
    return df.iloc[hits].to_dict(orient="records")
|
|
|
|
|
|
|
|
def main():
    """
    Run the Streamlit app: load the dataset, build the embedding index, answer queries.

    Streamlit re-executes the whole script on every widget interaction. The embedding
    model and the FAISS index are therefore kept in ``st.session_state`` and rebuilt
    only when the preprocessed texts change, instead of on every query.
    """
    st.title("Movie Recommendation Application with FAISS and Sentence-Transformers")

    df = load_dataset()
    if df is None:
        st.write("### Please load the dataset to proceed.")
        return

    st.write("### Preview of Loaded Dataset")
    st.dataframe(df.head())

    df = preprocess_data(df)
    if df.empty:
        st.warning("The loaded dataset has no rows, so there is nothing to index.")
        return

    st.write("### Creating Embeddings and Index...")
    if "embedding_model" not in st.session_state:
        st.session_state["embedding_model"] = SentenceTransformer('all-MiniLM-L6-v2')
    model = st.session_state["embedding_model"]

    # Rebuild the index only when the indexed texts differ from the cached ones.
    texts = df['text'].tolist()
    if st.session_state.get("faiss_index_texts") != texts:
        st.session_state["faiss_index"] = create_faiss_index(df, model)
        st.session_state["faiss_index_texts"] = texts
    index = st.session_state["faiss_index"]

    query = st.text_input("Enter a movie name or keyword for recommendations:")
    if not query:
        return

    st.write("### Query Results")
    results = retrieve(query, model, index, df)

    # .get(): columns such as 'year', 'short summary', 'director' and 'rating' are not
    # guaranteed to exist in an arbitrary CSV.
    sections = []
    for rank, record in enumerate(results, start=1):
        sections.append(
            f"**{rank}. {record.get('title', 'N/A')} ({record.get('year', 'N/A')})**\n"
            f"- **Genres**: {record.get('genres', 'N/A')}\n"
            f"- **Summary**: {record.get('short summary', 'N/A')}\n"
            f"- **Director**: {record.get('director', 'N/A')}\n"
            f"- **Cast**: {record.get('cast', 'N/A')}\n"
            f"- **Rating**: {record.get('rating', 'N/A')}\n\n"
        )
    st.write("".join(sections))


if __name__ == "__main__":
    main()
|
|
|