# chirag0107's picture
# Added required files
# 6ba8078
import pandas as pd
import gradio as gr
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from connect import DBConnect
__author__ = "Chirag Kamble"
class GradioDashboard:
    """
    Gradio dashboard with two tabs: a movie recommender that queries a vector
    store, and a movie plot generator that calls a Hugging Face text model.

    Instantiating the class loads the environment, connects to the database,
    builds the UI and launches it (see ``__init__``).
    """

    # Maps a "vibe" label to the score column used to rank recommendations.
    # "Balanced" and "In the feels" are the spellings the sorting code used
    # before it was aligned with the dropdown labels ("Neutral" and
    # "In the feels..."); they are kept so either spelling works.
    _VIBE_SORT_COLUMNS = {
        "Neutral": "neutral",
        "Balanced": "neutral",
        "Happy": "joy",
        "Mind-Bending": "surprise",
        "Scary": "fear",
        "In the feels...": "sadness",
        "In the feels": "sadness",
    }

    # Number of plot words shown in each recommendation caption.
    _PLOT_PREVIEW_WORDS = 30

    def __init__(self):
        """
        Initialize variable instances and launch the dashboard.

        Reads the ``HUGGINGFACE_TEXT_GENERATION_MODEL`` and ``HF_TOKEN``
        environment variables (after loading a ``.env`` file), connects to the
        database through ``DBConnect`` and finally calls
        ``generate_dashboard``.
        """
        load_dotenv()
        self.mongodb_vector_store, self.movies = DBConnect().connect_db()
        # Missing genres are dropped so that sorting never mixes NaN and str.
        genre_labels = self.movies["genre"].dropna().str.capitalize().unique()
        self.genres = ["All"] + sorted(genre_labels)
        self.vibe = ["Neutral", "Happy", "Mind-Bending", "Scary", "In the feels..."]
        # Both values are None when the corresponding variable is not set.
        self.huggingface_text_generation_model = os.getenv("HUGGINGFACE_TEXT_GENERATION_MODEL")
        self.huggingface_api_token = os.getenv("HF_TOKEN")
        self.generate_dashboard()

    @staticmethod
    def _drop_first_line(text):
        """
        Remove the first line of ``text`` when something follows it.

        :param text: Raw text returned by the LLM
        :return: Text after the first newline, or ``text`` unchanged when it
            has no newline or nothing but whitespace after it
        """
        _, newline, remainder = text.partition("\n")
        if newline and remainder.strip():
            return remainder
        return text

    @staticmethod
    def _format_directors(raw_directors):
        """
        Turn a comma-separated director string into a readable list.

        :param raw_directors: Director names separated by commas
        :return: ``"A"``, ``"A and B"`` or ``"A, B and C"``; the input is
            returned unchanged when it contains no names
        """
        names = [name.strip() for name in raw_directors.split(",") if name.strip()]
        if not names:
            return raw_directors
        if len(names) == 1:
            return names[0]
        return f"{', '.join(names[:-1])} and {names[-1]}"

    @staticmethod
    def _truncate_plot(plot, max_words=30):
        """
        Shorten a plot to at most ``max_words`` words.

        :param plot: Full plot text
        :param max_words: Maximum number of words to keep
        :return: The whitespace-normalized plot, followed by ``"..."`` only
            when words were actually dropped
        """
        words = plot.split()
        preview = " ".join(words[:max_words])
        if len(words) > max_words:
            preview += "..."
        return preview

    @staticmethod
    def _filter_by_genre(frame, genre):
        """
        Keep only the rows of ``frame`` that belong to ``genre``.

        The genre labels offered in the dropdown are capitalized, so both
        sides of the comparison are capitalized before matching.

        :param frame: Dataframe with a ``genre`` column
        :param genre: Selected genre, or ``"All"`` for no filtering
        :return: The filtered dataframe
        """
        if genre == "All":
            return frame
        return frame[frame["genre"].str.capitalize() == str(genre).capitalize()]

    @classmethod
    def _sort_by_vibe(cls, frame, vibe):
        """
        Order ``frame`` by the score column associated with ``vibe``.

        :param frame: Dataframe of candidate movies
        :param vibe: Selected vibe label
        :return: A new dataframe sorted in descending order of the matching
            column, or ``frame`` unchanged when the vibe is unknown or the
            column is absent
        """
        column = cls._VIBE_SORT_COLUMNS.get(vibe)
        if column is None or column not in frame.columns:
            return frame
        # Not sorted in place: ``frame`` may be a slice of a larger dataframe.
        return frame.sort_values(by=column, ascending=False)

    def query_data(self, query: str):
        """
        Movie Script Generation method: ask the LLM for a movie plot.

        :param query: A user query describing the desired movie
        :return llm_answer: String answer generated by the LLM, without its
            first line when more text follows it
        :raises gr.Error: When ``query`` is empty or only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a prompt to generate a response !", duration=5)
        hf_llm: HuggingFaceEndpoint = HuggingFaceEndpoint(
            repo_id=self.huggingface_text_generation_model,
            huggingfacehub_api_token=self.huggingface_api_token,
            temperature=0.1,
            task="text-generation",
            repetition_penalty=1.03,
            top_k=10,
            top_p=0.95,
            typical_p=0.95,
        )
        prompt = PromptTemplate.from_template(
            template="Generate a movie plot based on the below user query.\nBe creative but stay true to the "
                     "description provided.\nUser Query:{context}",
        )
        formatted_prompt = prompt.format(context=query)
        llm_answer = hf_llm.invoke(formatted_prompt)
        # NOTE(review): the first line is discarded, presumably because the
        # model starts its answer with a blank or echoed line -- confirm.
        return self._drop_first_line(llm_answer)

    def retrieve_recommendations(self, query, genre, vibe, initial_top_k=50, final_top_k=10) -> pd.DataFrame:
        """
        Method to retrieve the recommendation from the vector database

        :param query: User query
        :param genre: Selected genre, or "All" for every genre
        :param vibe: Selected vibe, used to order the final selection
        :param initial_top_k: Initial number of searched and selected movies
        :param final_top_k: Final number of recommended movies
        :return movies_recs: Final Dataframe of recommended movies
        """
        recs = self.mongodb_vector_store.similarity_search(query, k=initial_top_k)
        # The first whitespace-separated token of each document is used as
        # the movie uuid; documents without any token are skipped.
        movie_ids = []
        for rec in recs:
            tokens = rec.page_content.strip('"').split()
            if tokens:
                movie_ids.append(tokens[0])
        candidates = self.movies[self.movies["uuid"].isin(movie_ids)].head(initial_top_k)
        selected = self._filter_by_genre(candidates, genre).head(final_top_k)
        return self._sort_by_vibe(selected, vibe)

    def recommend_movies(self, query: str, genre: str, vibe: str) -> str:
        """
        Method to generate a string with the list of selected movies recommended

        :param query: User query
        :param genre: Selected genre
        :param vibe: Selected vibe
        :return output: String with the numbered list of recommended movies,
            or an apology message when nothing matches
        :raises gr.Error: When ``query`` is empty or only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a description to get recommendations !", duration=5)
        recommendations = self.retrieve_recommendations(query, genre, vibe)
        columns = zip(recommendations["title"], recommendations["director"], recommendations["plot"])
        results = []
        for position, (title, director, plot) in enumerate(columns, start=1):
            directors = self._format_directors(director)
            truncated_plot = self._truncate_plot(plot, self._PLOT_PREVIEW_WORDS)
            results.append(f"{position}. {title} by {directors}: {truncated_plot}")
        if not results:
            return "Sorry, our database movies does not have recommendations for the chosen Genre and Vibe :("
        return "\n\n\n".join(results)

    def generate_dashboard(self):
        """
        Build the two-tab Gradio interface, wire the buttons to
        ``recommend_movies`` and ``query_data``, and launch it.
        """
        theme = gr.themes.Citrus()
        with gr.Blocks(theme=theme) as dashboard:
            gr.Markdown("# Get Movies Recommendations or Generate Your Own Movie Script !!!")

            with gr.Tab(label="Movies Recommender"):
                gr.Markdown("# Movies Recommender")
                with gr.Row():
                    with gr.Column():
                        genre_dropdown = gr.Dropdown(choices=self.genres, label="Select A Genre", value="All")
                    with gr.Column():
                        vibe_dropdown = gr.Dropdown(choices=self.vibe, label="Choose Your Vibe", value="Neutral")
                with gr.Row():
                    user_query = gr.Textbox(
                        label="Please enter a description of the movie you would like to watch:",
                        placeholder="e.g. A story about love in war",
                    )
                with gr.Row():
                    submit_button = gr.Button("Recommend")
                gr.Markdown("## Recommendations")
                with gr.Row():
                    recommendations_output = gr.TextArea(
                        interactive=False,
                        label="Your recommendations will be displayed below:",
                        autoscroll=False,
                        show_label=True,
                        show_copy_button=True,
                    )
                submit_button.click(
                    fn=self.recommend_movies,
                    inputs=[user_query, genre_dropdown, vibe_dropdown],
                    outputs=[recommendations_output],
                )

            with gr.Tab("Movie Script Generator"):
                gr.Markdown("# Movie Script Generator")
                with gr.Row():
                    script_gen_query_textbox = gr.Textbox(
                        label="Enter your prompt here:",
                        lines=1,
                        placeholder="e.g. Generate a movie where a couple "
                                    "discovers love during a war",
                    )
                with gr.Row():
                    generate_button = gr.Button("Generate")
                with gr.Column():
                    script_output = gr.TextArea(
                        interactive=False,
                        placeholder="Your Movie Plot will be displayed here. "
                                    "Don't forget to invite us to your movie premier! :)",
                        autoscroll=False,
                        show_label=False,
                    )
                generate_button.click(
                    fn=self.query_data,
                    inputs=[script_gen_query_textbox],
                    outputs=[script_output],
                )

        dashboard.launch(debug=True)
if __name__ == "__main__":
    # Script entry point: constructing the dashboard is enough, because the
    # constructor itself builds the interface and launches it.
    GradioDashboard()