MohammadYaseen committed on
Commit
7dde2e4
·
verified ·
1 Parent(s): 1f831f3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # === Import Required Libraries ===
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sentence_transformers import SentenceTransformer
5
+ import faiss
6
+ import streamlit as st
7
+ import gradio as gr
8
+ import os
9
+
10
+ # === Dataset Loading Function ===
11
def load_dataset():
    """
    Render the dataset-selection widgets and return the chosen dataset.

    Three sources are offered through a Streamlit radio button: a manually
    uploaded CSV file, a dataset downloaded with the Kaggle command-line
    client, or a CSV file at a path on the machine running the app.

    Returns:
        pandas.DataFrame | None: The loaded dataset, or ``None`` when no
        dataset has been provided yet or loading failed (in which case a
        message is shown in the UI).
    """
    # Function-scope imports: only this function needs these stdlib modules.
    import glob
    import subprocess
    import tempfile

    def _read_csv(source):
        """Read *source* with pandas; report failures in the UI and return None."""
        try:
            return pd.read_csv(source)
        except (pd.errors.EmptyDataError, pd.errors.ParserError,
                UnicodeDecodeError, OSError) as exc:
            st.error(f"Could not read the CSV file: {exc}")
            return None

    st.write("### Dataset Upload Options")
    upload_option = st.radio(
        "Choose how to provide the dataset:",
        ("Manual Upload", "Download from Kaggle", "Specify Local Path")
    )

    # Manual Upload
    if upload_option == "Manual Upload":
        st.write("#### Upload the file below:")
        uploaded_file = st.file_uploader("Upload your CSV file", type="csv")
        if uploaded_file is not None:
            df = _read_csv(uploaded_file)
            if df is not None:
                st.success("File uploaded successfully!")
            return df

    # Kaggle Download
    elif upload_option == "Download from Kaggle":
        st.write("#### Enter your Kaggle Dataset Path and API Key")
        kaggle_dataset = st.text_input("Kaggle Dataset Path (e.g., `thedevastator/hydra-movies-dataset-directors-writers-cast-and`):").strip()
        kaggle_api_key = st.text_area("Enter your Kaggle API Key JSON content:")
        if st.button("Download Dataset"):
            if not (kaggle_dataset and kaggle_api_key):
                st.warning("Please provide both the dataset path and your API key.")
                return None

            # Set up Kaggle API credentials where the CLI looks for them.
            kaggle_dir = os.path.expanduser("~/.kaggle")
            os.makedirs(kaggle_dir, exist_ok=True)
            credentials_path = os.path.join(kaggle_dir, "kaggle.json")
            with open(credentials_path, "w", encoding="utf-8") as f:
                f.write(kaggle_api_key)
            os.chmod(credentials_path, 0o600)

            # Download into a fresh directory so the extracted files can be
            # discovered without guessing their names.
            download_dir = tempfile.mkdtemp(prefix="kaggle_")
            try:
                # Argument list (no shell): the user-supplied dataset path is
                # passed as a single argument and never interpreted by a shell.
                completed = subprocess.run(
                    ["kaggle", "datasets", "download", "-d", kaggle_dataset,
                     "-p", download_dir, "--unzip"],
                    capture_output=True,
                    text=True,
                    check=False,
                )
            except FileNotFoundError:
                st.error("The `kaggle` command-line tool is not installed or not on PATH.")
                return None

            csv_files = sorted(
                glob.glob(os.path.join(download_dir, "**", "*.csv"), recursive=True)
            )
            if completed.returncode != 0 or not csv_files:
                st.error("Failed to download dataset. Please check your inputs.")
                return None

            # Prefer a CSV named after the dataset slug; otherwise take the
            # first CSV found in the archive.
            preferred_name = kaggle_dataset.split("/")[-1] + ".csv"
            csv_path = next(
                (p for p in csv_files if os.path.basename(p) == preferred_name),
                csv_files[0],
            )
            df = _read_csv(csv_path)
            if df is not None:
                st.success(f"Dataset {os.path.basename(csv_path)} downloaded successfully!")
            return df

    # Specify Local Path
    elif upload_option == "Specify Local Path":
        local_path = st.text_input("Specify the full local path of your CSV file:")
        if st.button("Load Dataset"):
            # isfile (not exists): a directory path is not a loadable CSV.
            if os.path.isfile(local_path):
                df = _read_csv(local_path)
                if df is not None:
                    st.success("Dataset loaded successfully from the specified path!")
                return df
            st.error("File not found. Please check the path and try again.")

    return None
64
+
65
+ # === Preprocess Data ===
66
def preprocess_data(df):
    """
    Normalize column names and build the ``text`` column used for embeddings.

    Column names are stripped and lower-cased. A missing ``genres`` column is
    added with the placeholder ``"Unknown"``. The ``text`` column is the
    space-separated concatenation of ``title``, ``summary``, ``genres`` and
    ``cast``; missing values and missing columns contribute an empty string.

    Args:
        df: DataFrame to prepare. It is modified in place.

    Returns:
        The same DataFrame, with normalized column names and a ``text`` column.
    """
    def _as_text(column):
        """Return *column* as strings with NaN replaced by '' ('' if absent)."""
        if column not in df.columns:
            return pd.Series("", index=df.index)
        # fillna before astype: otherwise NaN would become the string "nan".
        return df[column].fillna("").astype(str)

    # Normalize column names
    df.columns = df.columns.str.strip().str.lower()

    # Verify and handle missing 'genres' column
    if 'genres' not in df.columns:
        print("Warning: 'genres' column missing! Adding a placeholder.")
        df['genres'] = "Unknown"

    # Every part goes through _as_text so a single NaN (or a non-string
    # value) cannot turn the whole concatenated text into NaN or raise.
    df['text'] = (
        _as_text('title') + " " + _as_text('summary') + " "
        + _as_text('genres') + " " + _as_text('cast')
    )

    return df
82
+
83
+ # === Create Embeddings and FAISS Index ===
84
def create_faiss_index(df, model):
    """
    Encode the ``text`` column with *model* and index the vectors with FAISS.

    Args:
        df: DataFrame containing a ``text`` column.
        model: Object with an ``encode(list_of_str, show_progress_bar=...)``
            method returning a 2-D array of embeddings (one row per text).

    Returns:
        A ``faiss.IndexFlatL2`` containing one vector per row of *df*.

    Raises:
        ValueError: If *df* has no rows, since there is nothing to index.
    """
    texts = df['text'].tolist()
    if not texts:
        raise ValueError("Cannot build a FAISS index from an empty dataset.")

    # FAISS expects C-contiguous float32 input; convert explicitly rather
    # than relying on the encoder's output dtype.
    embeddings = np.ascontiguousarray(
        model.encode(texts, show_progress_bar=True), dtype=np.float32
    )
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index
93
+
94
+ # === Define Retrieval Function ===
95
def retrieve(query, model, index, df, top_k=5):
    """
    Retrieve the rows of *df* closest to *query* according to the FAISS index.

    Args:
        query: Free-text query string.
        model: Object with an ``encode(list_of_str)`` method returning a 2-D
            array of embeddings.
        index: FAISS index exposing ``ntotal`` and ``search(vectors, k)``.
        df: DataFrame whose row positions correspond to the indexed vectors.
        top_k: Maximum number of results to return.

    Returns:
        A list of at most *top_k* row dictionaries, best match first. The list
        is empty when the index or *df* is empty or *top_k* is not positive.
    """
    # Never ask for more neighbours than exist: FAISS pads the result with
    # -1, which ``iloc`` would interpret as "last row".
    k = min(top_k, index.ntotal, len(df))
    if k <= 0:
        return []

    query_embedding = np.ascontiguousarray(model.encode([query]), dtype=np.float32)
    _, indices = index.search(query_embedding, k)

    # Defensive filter for any padding / out-of-range positions.
    positions = [int(i) for i in indices[0] if 0 <= i < len(df)]
    return df.iloc[positions].to_dict(orient="records")
103
+
104
+ # === Define Gradio Interface ===
105
def movie_query_app(query, model, index, df):
    """
    Format the movies retrieved for *query* as a Markdown string.

    Args:
        query: Free-text query string.
        model: Embedding model passed through to ``retrieve``.
        index: FAISS index passed through to ``retrieve``.
        df: DataFrame of movies passed through to ``retrieve``.

    Returns:
        One Markdown entry per result (title, year, genres, summary, director,
        cast, rating). Fields absent from the dataset are shown as ``N/A``.
        The string is empty when there are no results.
    """
    missing = "N/A"
    results = retrieve(query, model, index, df)

    entries = []
    for rank, res in enumerate(results, start=1):
        # Fall back to the full summary when the dataset has no short one.
        summary = res.get('short summary', res.get('summary', missing))
        entries.append(
            f"**{rank}. {res.get('title', missing)} ({res.get('year', missing)})**\n"
            f"- **Genres**: {res.get('genres', missing)}\n"
            f"- **Summary**: {summary}\n"
            f"- **Director**: {res.get('director', missing)}\n"
            f"- **Cast**: {res.get('cast', missing)}\n"
            f"- **Rating**: {res.get('rating', missing)}\n\n"
        )
    return "".join(entries)
119
+
120
+ # === Main Function ===
121
if __name__ == "__main__":
    # Script entry point: Streamlit collects the dataset, then a Gradio
    # interface is started to answer movie queries against it.
    st.title("RAG Application with Integrated Dataset Loading")

    # Step 1: Load dataset
    df = load_dataset()

    if df is not None:
        st.write("### Preview of Loaded Dataset")
        st.dataframe(df.head())

        # Step 2: Preprocess data
        df = preprocess_data(df)

        # Step 3: Create embeddings and FAISS index
        st.write("### Creating Embeddings and Index...")
        model = SentenceTransformer('all-MiniLM-L6-v2')
        index = create_faiss_index(df, model)

        # Step 4: Set up Gradio interface
        # The response is Markdown-formatted, so render it as Markdown.
        iface = gr.Interface(
            fn=lambda query: movie_query_app(query, model, index, df),
            inputs="text",
            outputs="markdown",
            title="Movie Recommendation App",
        )

        # Streamlit re-executes this script on every interaction; shut down
        # the interface started by a previous run so servers do not pile up.
        if "gradio_iface" in st.session_state:
            st.session_state["gradio_iface"].close()

        # Step 5: Launch the app
        st.write("### Launching Gradio App...")
        # prevent_thread_lock=True: a blocking launch() would never return
        # control to Streamlit, leaving the page stuck on this run.
        _, local_url, _ = iface.launch(prevent_thread_lock=True)
        st.session_state["gradio_iface"] = iface
        st.write(f"Gradio app running at {local_url}")
    else:
        st.write("### Please load the dataset to proceed.")