|
|
import pandas as pd |
|
|
import gradio as gr |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from langchain_huggingface import HuggingFaceEndpoint |
|
|
from langchain_core.prompts import PromptTemplate |
|
|
|
|
|
from connect import DBConnect |
|
|
|
|
|
# Module metadata: author credit for this dashboard script.
__author__ = "Chirag Kamble"
|
|
|
|
|
|
|
|
class GradioDashboard:
    """
    Simple Gradio dashboard with two tabs: a movie recommender backed by a
    vector store plus a movies table, and a movie-plot generator backed by a
    Hugging Face text-generation endpoint.

    Instantiating the class connects to the database and immediately builds
    and launches the dashboard (``__init__`` calls ``generate_dashboard``).
    """

    # Maps each vibe choice to the emotion-score column used to order the final
    # recommendations (highest score first). "Neutral" and "In the feels..." are
    # the labels offered in the vibe dropdown; "Balanced" and "In the feels" are
    # also accepted for callers that pass those spellings. Unknown vibes leave
    # the relevance order untouched.
    _VIBE_SORT_COLUMNS = {
        "Neutral": "neutral",
        "Balanced": "neutral",
        "Happy": "joy",
        "Mind-Bending": "surprise",
        "Scary": "fear",
        "In the feels...": "sadness",
        "In the feels": "sadness",
    }

    def __init__(self):
        """
        Load environment variables, connect to the database, prepare the
        dropdown choices and launch the dashboard.
        """
        load_dotenv()

        # connect_db() returns the vector store used for similarity search and
        # the movies table used to look up full movie records.
        self.mongodb_vector_store, self.movies = DBConnect().connect_db()

        # Genre choices are shown capitalized; retrieve_recommendations applies
        # the same normalization before comparing. Missing genres are skipped so
        # a NaN value cannot break capitalization or sorting.
        genres = self.movies["genre"].dropna().astype(str).str.capitalize().unique()
        self.genres = ["All"] + sorted(genres)

        self.vibe = ["Neutral", "Happy", "Mind-Bending", "Scary", "In the feels..."]

        # os.getenv returns None when a variable is unset.
        self.huggingface_text_generation_model = os.getenv("HUGGINGFACE_TEXT_GENERATION_MODEL")
        self.huggingface_api_token = os.getenv("HF_TOKEN")

        self.generate_dashboard()

    @staticmethod
    def _drop_first_line(text: str) -> str:
        """Return everything after the first newline of ``text``, or ``text`` itself if it has none."""
        head, separator, tail = text.partition("\n")
        return tail if separator else head

    @staticmethod
    def _format_directors(director) -> str:
        """Render a comma-separated director list as "A", "A and B" or "A, B and C"."""
        if not isinstance(director, str):
            # e.g. a missing (NaN) value in the movies table
            return "Unknown"
        names = [name.strip() for name in director.split(",") if name.strip()]
        if len(names) > 2:
            return f"{', '.join(names[:-1])} and {names[-1]}"
        if len(names) == 2:
            return " and ".join(names)
        return names[0] if names else director

    @staticmethod
    def _truncate_plot(plot, max_words: int = 30) -> str:
        """Collapse whitespace in ``plot`` and cut it to ``max_words`` words, adding "..." only if cut."""
        if not isinstance(plot, str):
            return ""
        words = plot.split()
        if len(words) <= max_words:
            return " ".join(words)
        return " ".join(words[:max_words]) + "..."

    def query_data(self, query: str) -> str:
        """
        Movie script generation method: ask the configured Hugging Face
        text-generation model to write a movie plot for the user's query.

        :param query: A user query describing the desired movie
        :return llm_answer: String answer generated by the LLM, with its first line removed
        :raises gr.Error: If the query is empty or only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a prompt to generate a response !", duration=5)

        hf_llm: HuggingFaceEndpoint = HuggingFaceEndpoint(
            repo_id=self.huggingface_text_generation_model,
            huggingfacehub_api_token=self.huggingface_api_token,
            temperature=0.1,
            task="text-generation",
            repetition_penalty=1.03,
            top_k=10,
            top_p=0.95,
            typical_p=0.95,
        )

        prompt = PromptTemplate.from_template(
            template="Generate a movie plot based on the below user query.\nBe creative but stay true to the "
                     "description provided.\nUser Query:{context}",
        )

        formatted_prompt = prompt.format(context=query)
        llm_answer = hf_llm.invoke(formatted_prompt)

        # The first line of the completion is discarded (NOTE(review): presumably
        # a leading blank/echoed line; confirm against the model's output). A
        # single-line answer is returned whole instead of raising IndexError.
        return self._drop_first_line(llm_answer)

    def retrieve_recommendations(self, query: str, genre: str, vibe: str,
                                 initial_top_k: int = 50, final_top_k: int = 10) -> pd.DataFrame:
        """
        Method to retrieve the recommendations from the vector database.

        :param query: User query
        :param genre: Selected genre (capitalized, as shown in the dropdown) or "All"
        :param vibe: Selected vibe; see _VIBE_SORT_COLUMNS for the recognized values
        :param initial_top_k: Number of documents fetched from the similarity search
        :param final_top_k: Maximum number of recommended movies returned
        :return movies_recs: Rows of the movies table for the recommended movies, most
            relevant first, then re-ordered by the vibe's emotion score when the vibe is known
        """
        recs = self.mongodb_vector_store.similarity_search(query, k=initial_top_k)

        # The movie uuid is the first whitespace-separated token of each document.
        movies_list = [rec.page_content.strip('"').split()[0] for rec in recs]

        # Rank of each uuid in the search results (first occurrence wins), so the
        # matched rows can be put back into relevance order; isin() alone keeps the
        # table's storage order.
        relevance_rank = {uuid: rank for rank, uuid in enumerate(dict.fromkeys(movies_list))}
        matches = self.movies[self.movies["uuid"].isin(list(relevance_rank))]
        order = matches["uuid"].map(relevance_rank).to_numpy().argsort(kind="stable")
        movies_recs = matches.iloc[order].head(initial_top_k)

        if genre != "All":
            # Compare with the same capitalization used to build the dropdown choices.
            movies_recs = movies_recs[movies_recs["genre"].astype(str).str.capitalize() == genre]
        movies_recs = movies_recs.head(final_top_k)

        sort_column = self._VIBE_SORT_COLUMNS.get(vibe)
        if sort_column is not None:
            # Stable sort keeps relevance order among equal emotion scores.
            movies_recs = movies_recs.sort_values(by=sort_column, ascending=False, kind="stable")

        return movies_recs

    def recommend_movies(self, query: str, genre: str, vibe: str) -> str:
        """
        Method to generate a string with the list of selected movies recommended.

        :param query: User query
        :param genre: Selected genre or "All"
        :param vibe: Selected vibe
        :return output: Numbered recommendations ("N. Title by Directors: plot"), separated
            by blank lines, or an apology message when nothing matches
        :raises gr.Error: If the query is empty or only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a description of the movie to get recommendations !", duration=5)

        recommendations = self.retrieve_recommendations(query, genre, vibe)

        results = [
            f"{position}. {row['title']} by {self._format_directors(row['director'])}: "
            f"{self._truncate_plot(row['plot'])}"
            for position, (_, row) in enumerate(recommendations.iterrows(), start=1)
        ]

        if not results:
            return "Sorry, our database movies does not have recommendations for the chosen Genre and Vibe :("
        return "\n\n\n".join(results)

    def generate_dashboard(self):
        """
        Build the two-tab Gradio interface and launch it with ``debug=True``.

        The "Movies Recommender" tab wires ``recommend_movies`` to the genre,
        vibe and description inputs; the "Movie Script Generator" tab wires
        ``query_data`` to a free-text prompt.
        """
        theme = gr.themes.Citrus()

        with gr.Blocks(theme=theme) as dashboard:
            gr.Markdown("# Get Movies Recommendations or Generate Your Own Movie Script !!!")

            with gr.Tab(label="Movies Recommender"):
                gr.Markdown("# Movies Recommender")

                with gr.Row():
                    with gr.Column():
                        genre_dropdown = gr.Dropdown(choices=self.genres, label="Select A Genre", value="All")
                    with gr.Column():
                        vibe_dropdown = gr.Dropdown(choices=self.vibe, label="Choose Your Vibe", value="Neutral")

                with gr.Row():
                    user_query = gr.Textbox(label="Please enter a description of the movie you would like to watch:",
                                            placeholder="e.g. A story about love in war")

                with gr.Row():
                    submit_button = gr.Button("Recommend")

                gr.Markdown("## Recommendations")

                with gr.Row():
                    recommendations_output = gr.TextArea(interactive=False,
                                                         label="Your recommendations will be displayed below:",
                                                         autoscroll=False,
                                                         show_label=True,
                                                         show_copy_button=True, )

                submit_button.click(fn=self.recommend_movies,
                                    inputs=[user_query, genre_dropdown, vibe_dropdown],
                                    outputs=[recommendations_output], )

            with gr.Tab("Movie Script Generator"):
                gr.Markdown("# Movie Script Generator")

                with gr.Row():
                    script_gen_query_textbox = gr.Textbox(label="Enter your prompt here:", lines=1,
                                                          placeholder="e.g. Generate a movie where a couple "
                                                                      "discovers love during a war")

                with gr.Row():
                    button = gr.Button("Generate")

                with gr.Column():
                    script_output = gr.TextArea(interactive=False,
                                                placeholder="Your Movie Plot will be displayed here. "
                                                            "Don't forget to invite us to your movie premier! :)",
                                                autoscroll=False,
                                                show_label=False,
                                                )

            button.click(fn=self.query_data, inputs=[script_gen_query_textbox], outputs=[script_output])

        dashboard.launch(debug=True)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Constructing the dashboard connects to the database and launches the
    # Gradio app; nothing else needs to run at script entry.
    dashboard_app = GradioDashboard()
|
|
|