File size: 8,037 Bytes
6ba8078 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import pandas as pd
import gradio as gr
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from connect import DBConnect
__author__ = "Chirag Kamble"
class GradioDashboard:
    """
    Gradio dashboard with two tabs: a movie recommender backed by a vector
    store and a movie plot generator backed by a Hugging Face text model.
    """

    # Maps a "vibe" dropdown label to the emotion-score column used to rank
    # recommendations. "Balanced" and "In the feels" are earlier spellings of
    # "Neutral" and "In the feels..." and are kept so that callers passing the
    # old labels still get a sorted result.
    # The "Rage" (anger) and "Gruesome" (disgust) vibes are currently disabled.
    _VIBE_SORT_COLUMNS = {
        "Neutral": "neutral",
        "Balanced": "neutral",
        "Happy": "joy",
        "Mind-Bending": "surprise",
        "Scary": "fear",
        "In the feels...": "sadness",
        "In the feels": "sadness",
    }

    # Number of words of the plot shown in each recommendation caption.
    _PLOT_PREVIEW_WORDS = 30

    def __init__(self):
        """
        Load the environment, connect to the database, prepare the dropdown
        options and build + launch the dashboard.
        """
        load_dotenv()
        self.mongodb_vector_store, self.movies = DBConnect().connect_db()
        # Dropdown labels are the capitalized genre names; missing genres are skipped.
        genre_labels = self.movies["genre"].dropna().str.capitalize().unique()
        self.genres = ["All"] + sorted(genre_labels)
        self.vibe = ["Neutral", "Happy", "Mind-Bending", "Scary", "In the feels..."]
        # Both values are None when the corresponding environment variable is unset.
        self.huggingface_text_generation_model = os.getenv("HUGGINGFACE_TEXT_GENERATION_MODEL")
        self.huggingface_api_token = os.getenv("HF_TOKEN")
        self.generate_dashboard()

    def query_data(self, query: str) -> str:
        """
        Generate a movie plot for the user prompt with the configured LLM.

        :param query: A user prompt describing the desired movie
        :return llm_answer: String answer generated by the LLM
        :raises gr.Error: When the prompt is empty or only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a prompt to generate a response !", duration=5)

        hf_llm = HuggingFaceEndpoint(
            repo_id=self.huggingface_text_generation_model,
            huggingfacehub_api_token=self.huggingface_api_token,
            temperature=0.1,
            task="text-generation",
            repetition_penalty=1.03,
            top_k=10,
            top_p=0.95,
            typical_p=0.95,
        )
        prompt = PromptTemplate.from_template(
            template="Generate a movie plot based on the below user query.\nBe creative but stay true to the "
                     "description provided.\nUser Query:{context}",
        )
        formatted_prompt = prompt.format(context=query)
        llm_answer = hf_llm.invoke(formatted_prompt)

        # Drop the first line of the answer, as before. When the answer has no
        # further content after the first line, return it unchanged instead of
        # failing with IndexError or returning an empty string.
        _, _, remainder = llm_answer.partition("\n")
        return remainder if remainder.strip() else llm_answer

    def retrieve_recommendations(self, query, genre, vibe, initial_top_k=50, final_top_k=10) -> pd.DataFrame:
        """
        Retrieve recommended movies from the vector database.

        :param query: User query
        :param genre: Selected genre label, or "All" to skip genre filtering
        :param vibe: Selected vibe label; unknown labels leave the order untouched
        :param initial_top_k: Number of documents requested from the similarity search
        :param final_top_k: Maximum number of recommended movies returned
        :return movies_recs: Dataframe of recommended movies
        """
        recs = self.mongodb_vector_store.similarity_search(query, k=initial_top_k)

        # The first whitespace-delimited token of each document (after removing
        # surrounding double quotes) is matched against the "uuid" column.
        movie_ids = []
        for rec in recs:
            tokens = rec.page_content.strip('"').split()
            if tokens:
                movie_ids.append(tokens[0])

        # NOTE(review): isin() keeps the row order of self.movies, not the
        # similarity ranking returned by the vector store.
        candidates = self.movies[self.movies["uuid"].isin(movie_ids)].head(initial_top_k)

        if genre != "All":
            # The dropdown shows capitalized labels, so compare in the same form.
            same_genre = candidates["genre"].str.capitalize() == str(genre).capitalize()
            candidates = candidates[same_genre]
        movies_recs = candidates.head(final_top_k)

        sort_column = self._VIBE_SORT_COLUMNS.get(vibe)
        if sort_column is not None:
            # Non-inplace sort: movies_recs is a slice of self.movies.
            movies_recs = movies_recs.sort_values(by=sort_column, ascending=False)
        return movies_recs

    @staticmethod
    def _format_directors(raw_directors: str) -> str:
        """Turn a comma-separated director string into 'A, B and C' form."""
        names = [name.strip() for name in raw_directors.split(",") if name.strip()]
        if not names:
            return raw_directors
        if len(names) == 1:
            return names[0]
        return f"{', '.join(names[:-1])} and {names[-1]}"

    def recommend_movies(self, query: str, genre: str, vibe: str) -> str:
        """
        Build the text listing the recommended movies.

        :param query: User query
        :param genre: Selected genre label
        :param vibe: Selected vibe label
        :return output: String with the list of recommended movies
        :raises gr.Error: When the query is empty or only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a description to get recommendations !", duration=5)

        recommendations = self.retrieve_recommendations(query, genre, vibe)
        results = []
        for position, (_, row) in enumerate(recommendations.iterrows(), start=1):
            plot_words = row["plot"].split()
            truncated_plot = " ".join(plot_words[:self._PLOT_PREVIEW_WORDS])
            if len(plot_words) > self._PLOT_PREVIEW_WORDS:
                # Only signal truncation when words were actually dropped.
                truncated_plot += "..."
            directors = self._format_directors(row["director"])
            results.append(f"{position}. {row['title']} by {directors}: {truncated_plot}")

        if not results:
            return "Sorry, our database movies does not have recommendations for the chosen Genre and Vibe :("
        return "\n\n\n".join(results)

    def generate_dashboard(self):
        """
        Build the two-tab Gradio interface, wire the buttons to their handlers
        and launch the app.
        """
        theme = gr.themes.Citrus()
        with gr.Blocks(theme=theme) as dashboard:
            gr.Markdown("# Get Movies Recommendations or Generate Your Own Movie Script !!!")

            with gr.Tab(label="Movies Recommender"):
                gr.Markdown("# Movies Recommender")
                with gr.Row():
                    with gr.Column():
                        genre_dropdown = gr.Dropdown(choices=self.genres, label="Select A Genre", value="All")
                    with gr.Column():
                        vibe_dropdown = gr.Dropdown(choices=self.vibe, label="Choose Your Vibe", value="Neutral")
                with gr.Row():
                    user_query = gr.Textbox(label="Please enter a description of the movie you would like to watch:",
                                            placeholder="e.g. A story about love in war")
                with gr.Row():
                    submit_button = gr.Button("Recommend")
                gr.Markdown("## Recommendations")
                with gr.Row():
                    recommendations_output = gr.TextArea(interactive=False,
                                                         label="Your recommendations will be displayed below:",
                                                         autoscroll=False,
                                                         show_label=True,
                                                         show_copy_button=True, )
                submit_button.click(fn=self.recommend_movies,
                                    inputs=[user_query, genre_dropdown, vibe_dropdown],
                                    outputs=[recommendations_output], )

            with gr.Tab("Movie Script Generator"):
                gr.Markdown("# Movie Script Generator")
                with gr.Row():
                    script_gen_query_textbox = gr.Textbox(label="Enter your prompt here:", lines=1,
                                                          placeholder="e.g. Generate a movie where a couple "
                                                                      "discovers love during a war")
                with gr.Row():
                    generate_button = gr.Button("Generate")
                with gr.Column():
                    script_output = gr.TextArea(interactive=False,
                                                placeholder="Your Movie Plot will be displayed here. "
                                                            "Don't forget to invite us to your movie premier! :)",
                                                autoscroll=False,
                                                show_label=False,
                                                )
                generate_button.click(fn=self.query_data, inputs=[script_gen_query_textbox],
                                      outputs=[script_output])

        dashboard.launch(debug=True)
if __name__ == "__main__":
    # Instantiating the class is all that is needed: its constructor ends by
    # calling generate_dashboard(), which builds and launches the Gradio app.
    GradioDashboard()
|