chirag0107 commited on
Commit
6ba8078
·
1 Parent(s): 8595104

Added required files

Browse files
Files changed (3) hide show
  1. app.py +195 -0
  2. connect.py +46 -0
  3. requirements.txt +90 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import gradio as gr
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from langchain_huggingface import HuggingFaceEndpoint
6
+ from langchain_core.prompts import PromptTemplate
7
+
8
+ from connect import DBConnect
9
+
10
+ __author__ = "Chirag Kamble"
11
+
12
+
13
+ class GradioDashboard:
14
+ """
15
+ Class to generate a simple Gradio Dashboard
16
+ """
17
+ def __init__(self):
18
+ """
19
+ Initialize variable instances and methods
20
+ """
21
+ load_dotenv()
22
+
23
+ self.mongodb_vector_store, self.movies = DBConnect().connect_db()
24
+ self.genres = ["All"] + sorted(self.movies["genre"].apply(lambda x: x.capitalize()).unique())
25
+ self.vibe = ["Neutral", "Happy", "Mind-Bending", "Scary", "In the feels..."]
26
+ self.huggingface_text_generation_model: str = os.getenv("HUGGINGFACE_TEXT_GENERATION_MODEL")
27
+ self.huggingface_api_token: str = os.getenv("HF_TOKEN")
28
+
29
+ self.generate_dashboard()
30
+
31
+ def query_data(self, query: str):
32
+ """
33
+ Movie Script Generation method to Query data from Atlas Vector Search
34
+ :param query: A user query to search
35
+ :return llm_answer: String answer generated by the LLM
36
+ """
37
+ if len(query) == 0:
38
+ raise gr.Error("Enter a prompt to generate a response !", duration=5)
39
+
40
+ hf_llm: HuggingFaceEndpoint = HuggingFaceEndpoint(
41
+ repo_id=self.huggingface_text_generation_model,
42
+ huggingfacehub_api_token=self.huggingface_api_token,
43
+ temperature=0.1,
44
+ task="text-generation",
45
+ repetition_penalty=1.03,
46
+ top_k=10,
47
+ top_p=0.95,
48
+ typical_p=0.95,
49
+ )
50
+
51
+ prompt = PromptTemplate.from_template(
52
+ template="Generate a movie plot based on the below user query.\nBe creative but stay true to the "
53
+ "description provided.\nUser Query:{context}",
54
+ )
55
+
56
+ formatted_prompt = prompt.format(context=query)
57
+ llm_answer = hf_llm.invoke(formatted_prompt)
58
+ llm_answer = llm_answer.split("\n", 1)[1]
59
+
60
+ return llm_answer
61
+
62
+ def retrieve_recommendations(self, query, genre, vibe, initial_top_k=50, final_top_k=10) -> pd.DataFrame:
63
+ """
64
+ Method to retrieve the recommendation from the vector database
65
+ :param query: User query
66
+ :param genre: List of genres available
67
+ :param vibe: List of vibes options available
68
+ :param initial_top_k: Initial number of searched and selected movies
69
+ :param final_top_k: Final number of recommended movies
70
+
71
+ :return movies_recs: Final Dataframe of recommended movies
72
+ """
73
+ recs = self.mongodb_vector_store.similarity_search(query, k=initial_top_k)
74
+ movies_list = [rec.page_content.strip('"').split()[0] for rec in recs]
75
+ movies_recs = self.movies[self.movies["uuid"].isin(movies_list)].head(initial_top_k)
76
+
77
+ if genre != "All":
78
+ movies_recs = movies_recs[movies_recs["genre"] == genre][: final_top_k]
79
+ else:
80
+ movies_recs = movies_recs.head(final_top_k)
81
+
82
+ if vibe == "Balanced":
83
+ movies_recs.sort_values(by="neutral", ascending=False, inplace=True)
84
+ elif vibe == "Happy":
85
+ movies_recs.sort_values(by="joy", ascending=False, inplace=True)
86
+ elif vibe == "Mind-Bending":
87
+ movies_recs.sort_values(by="surprise", ascending=False, inplace=True)
88
+ # elif vibe == "Rage":
89
+ # movies_recs.sort_values(by="anger", ascending=False, inplace=True)
90
+ elif vibe == "Scary":
91
+ movies_recs.sort_values(by="fear", ascending=False, inplace=True)
92
+ elif vibe == "In the feels":
93
+ movies_recs.sort_values(by="sadness", ascending=False, inplace=True)
94
+ # elif vibe == "Gruesome":
95
+ # movies_recs.sort_values(by="disgust", ascending=False, inplace=True)
96
+
97
+ return movies_recs
98
+
99
+ def recommend_movies(self, query: str, genre: str, vibe: str) -> str:
100
+ """
101
+ Method to generate a string with the list of selected movies recommended
102
+ :param query: User query
103
+ :param genre: List of Genres available
104
+ :param vibe: List of Vibe options available
105
+
106
+ :return output: String with the list of recommended movies
107
+ """
108
+ recommendations = self.retrieve_recommendations(query, genre, vibe)
109
+
110
+ results = []
111
+ for i in range(len(recommendations)):
112
+ row = recommendations.iloc[i]
113
+
114
+ plot_split = row["plot"].split()
115
+ truncated_plot = " ".join(plot_split[:30]) + "..."
116
+
117
+ director_split = row["director"].split(",")
118
+
119
+ if len(director_split) > 2:
120
+ directors = f"{', '.join(director_split[:-1])} and {director_split[-1]}"
121
+ elif len(director_split) == 2:
122
+ directors = "and".join(director_split)
123
+ else:
124
+ directors = row["director"]
125
+
126
+ caption = f"{i+1}. {row['title']} by {directors}: {truncated_plot}"
127
+
128
+ results.append(caption)
129
+
130
+ if len(results) == 0:
131
+ output = "Sorry, our database movies does not have recommendations for the chosen Genre and Vibe :("
132
+ else:
133
+ output = "\n\n\n".join(results)
134
+
135
+ return output
136
+
137
+ def generate_dashboard(self):
138
+ theme = gr.themes.Citrus()
139
+ with gr.Blocks(theme=theme) as dashboard:
140
+ gr.Markdown("# Get Movies Recommendations or Generate Your Own Movie Script !!!")
141
+ with gr.Tab(label="Movies Recommender"):
142
+ gr.Markdown("# Movies Recommender")
143
+
144
+ with gr.Row():
145
+ with gr.Column():
146
+ genre_dropdown = gr.Dropdown(choices=self.genres, label="Select A Genre", value="All")
147
+ with gr.Column():
148
+ vibe_dropdown = gr.Dropdown(choices=self.vibe, label="Choose Your Vibe", value="Neutral")
149
+
150
+ with gr.Row():
151
+ user_query = gr.Textbox(label="Please enter a description of the movie you would like to watch:",
152
+ placeholder="e.g. A story about love in war")
153
+
154
+ with gr.Row():
155
+ submit_button = gr.Button("Recommend")
156
+
157
+ gr.Markdown("## Recommendations")
158
+
159
+ with gr.Row():
160
+ output = gr.TextArea(interactive=False,
161
+ label="Your recommendations will be displayed below:",
162
+ autoscroll=False,
163
+ show_label=True,
164
+ show_copy_button=True, )
165
+
166
+ submit_button.click(fn=self.recommend_movies,
167
+ inputs=[user_query, genre_dropdown, vibe_dropdown],
168
+ outputs=[output], )
169
+
170
+ with gr.Tab("Movie Script Generator"):
171
+ gr.Markdown("# Movie Script Generator")
172
+
173
+ with gr.Row():
174
+ script_gen_query_textbox = gr.Textbox(label="Enter your prompt here:", lines=1,
175
+ placeholder="e.g. Generate a movie where a couple "
176
+ "discovers love during a war")
177
+
178
+ with gr.Row():
179
+ button = gr.Button("Generate")
180
+
181
+ with gr.Column():
182
+ output = gr.TextArea(interactive=False,
183
+ placeholder="Your Movie Plot will be displayed here. "
184
+ "Don't forget to invite us to your movie premier! :)",
185
+ autoscroll=False,
186
+ show_label=False,
187
+ )
188
+
189
+ button.click(fn=self.query_data, inputs=[script_gen_query_textbox], outputs=[output])
190
+
191
+ dashboard.launch(debug=True)
192
+
193
+
194
+ if __name__ == "__main__":
195
+ GradioDashboard()
connect.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import pandas as pd
3
+ import os
4
+ import pymongo
5
+ from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+
8
+ __author__ = "Chirag Kamble"
9
+
10
+ class DBConnect:
11
+ """
12
+ Class to connect to the database
13
+ """
14
+ @staticmethod
15
+ def connect_db():
16
+ """
17
+ Static method to connect to the database and create a vector store
18
+ :return: mongodb_vector_store: MongoDB Atlas Vector Store instance connected to the required mongodb collection
19
+ :return: movies: dataframe containing all movies in the database
20
+ """
21
+ load_dotenv()
22
+
23
+ mongodb_connection_url = os.getenv("MONGODB_CONNECTION_URL")
24
+ mongodb_db_name: str = os.getenv("MONGODB_DB_NAME")
25
+ mongodb_collection_name: str = os.getenv("MONGODB_COLLECTION_NAME")
26
+ mongodb_vector_index: str = os.getenv("MONGODB_VECTOR_INDEX_NAME")
27
+ text_key: str = os.getenv("TEXT_KEY")
28
+ embedding_key: str = os.getenv("EMBEDDING_KEY")
29
+ relevance_score_fn = os.getenv("RELEVANCE_SCORE_FN")
30
+
31
+ client = pymongo.MongoClient(mongodb_connection_url)
32
+ db = client[mongodb_db_name]
33
+ collection = db[mongodb_collection_name]
34
+
35
+ mongodb_vector_store = MongoDBAtlasVectorSearch(collection=collection,
36
+ embedding=HuggingFaceEmbeddings(),
37
+ index_name=mongodb_vector_index,
38
+ relevance_score_fn=relevance_score_fn,
39
+ text_key=text_key,
40
+ embedding_key=embedding_key,
41
+ )
42
+
43
+ movies_docs = collection.find()
44
+ movies = pd.DataFrame(movies_docs)
45
+
46
+ return mongodb_vector_store, movies
requirements.txt ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.4.6
3
+ aiohttp==3.11.12
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ anyio==4.8.0
7
+ attrs==25.1.0
8
+ certifi==2025.1.31
9
+ charset-normalizer==3.4.1
10
+ click==8.1.8
11
+ colorama==0.4.6
12
+ dnspython==2.7.0
13
+ fastapi==0.115.8
14
+ ffmpy==0.5.0
15
+ filelock==3.17.0
16
+ frozenlist==1.5.0
17
+ fsspec==2025.2.0
18
+ gradio==5.15.0
19
+ gradio_client==1.7.0
20
+ greenlet==3.1.1
21
+ h11==0.14.0
22
+ httpcore==1.0.7
23
+ httpx==0.28.1
24
+ huggingface-hub==0.28.1
25
+ idna==3.10
26
+ Jinja2==3.1.5
27
+ joblib==1.4.2
28
+ jsonpatch==1.33
29
+ jsonpointer==3.0.0
30
+ langchain==0.3.18
31
+ langchain-core==0.3.34
32
+ langchain-huggingface==0.1.2
33
+ langchain-mongodb==0.4.0
34
+ langchain-text-splitters==0.3.6
35
+ langsmith==0.3.7
36
+ markdown-it-py==3.0.0
37
+ MarkupSafe==2.1.5
38
+ mdurl==0.1.2
39
+ mpmath==1.3.0
40
+ multidict==6.1.0
41
+ networkx==3.4.2
42
+ numpy==2.2.2
43
+ orjson==3.10.15
44
+ packaging==24.2
45
+ pandas==2.2.3
46
+ pillow==11.1.0
47
+ propcache==0.2.1
48
+ pydantic==2.10.6
49
+ pydantic_core==2.27.2
50
+ pydub==0.25.1
51
+ Pygments==2.19.1
52
+ pymongo==4.11
53
+ python-dateutil==2.9.0.post0
54
+ python-dotenv==1.0.1
55
+ python-multipart==0.0.20
56
+ pytz==2025.1
57
+ PyYAML==6.0.2
58
+ regex==2024.11.6
59
+ requests==2.32.3
60
+ requests-toolbelt==1.0.0
61
+ rich==13.9.4
62
+ ruff==0.9.5
63
+ safehttpx==0.1.6
64
+ safetensors==0.5.2
65
+ scikit-learn==1.6.1
66
+ scipy==1.15.1
67
+ semantic-version==2.10.0
68
+ sentence-transformers==3.4.1
69
+ setuptools==75.8.0
70
+ shellingham==1.5.4
71
+ six==1.17.0
72
+ sniffio==1.3.1
73
+ SQLAlchemy==2.0.38
74
+ starlette==0.45.3
75
+ sympy==1.13.1
76
+ tenacity==9.0.0
77
+ threadpoolctl==3.5.0
78
+ tokenizers==0.21.0
79
+ tomlkit==0.13.2
80
+ torch==2.6.0
81
+ tqdm==4.67.1
82
+ transformers==4.48.3
83
+ typer==0.15.1
84
+ typing_extensions==4.12.2
85
+ tzdata==2025.1
86
+ urllib3==2.3.0
87
+ uvicorn==0.34.0
88
+ websockets==14.2
89
+ yarl==1.18.3
90
+ zstandard==0.23.0