# File size: 8,037 Bytes
# 6ba8078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import pandas as pd
import gradio as gr
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate

from connect import DBConnect

# Module-level author metadata.
__author__ = "Chirag Kamble"


class GradioDashboard:
    """
    Gradio dashboard with two tabs: a movie recommender backed by a vector store
    and a movie plot generator backed by a Hugging Face text-generation endpoint.

    Instantiating the class loads the environment, connects to the database and
    launches the dashboard.
    """

    # Vibe label -> emotion-score column used to order the recommendations.
    # "Balanced" and "In the feels" are aliases of the dropdown labels "Neutral" and
    # "In the feels..." so that both spellings select the same column.
    _VIBE_SORT_COLUMNS = {
        "Neutral": "neutral",
        "Balanced": "neutral",
        "Happy": "joy",
        "Mind-Bending": "surprise",
        "Scary": "fear",
        "In the feels...": "sadness",
        "In the feels": "sadness",
    }

    # Maximum number of plot words shown for each recommendation.
    _PLOT_PREVIEW_WORDS = 30

    def __init__(self):
        """
        Initialize variable instances and launch the dashboard
        """
        load_dotenv()

        self.mongodb_vector_store, self.movies = DBConnect().connect_db()

        # Rows without a genre are skipped so that the labels can be capitalized and sorted.
        genre_labels = self.movies["genre"].dropna().str.capitalize().unique()
        self.genres = ["All"] + sorted(genre_labels)
        self.vibe = ["Neutral", "Happy", "Mind-Bending", "Scary", "In the feels..."]

        # os.getenv returns None when the variable is not set.
        self.huggingface_text_generation_model: str | None = os.getenv("HUGGINGFACE_TEXT_GENERATION_MODEL")
        self.huggingface_api_token: str | None = os.getenv("HF_TOKEN")

        self.generate_dashboard()

    def query_data(self, query: str):
        """
        Movie Script Generation method: sends the user query to the text-generation LLM
        :param query: A user query describing the movie to generate

        :return llm_answer: String answer generated by the LLM
        :raises gr.Error: When the query is empty or contains only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a prompt to generate a response !", duration=5)

        hf_llm: HuggingFaceEndpoint = HuggingFaceEndpoint(
            repo_id=self.huggingface_text_generation_model,
            huggingfacehub_api_token=self.huggingface_api_token,
            temperature=0.1,
            task="text-generation",
            repetition_penalty=1.03,
            top_k=10,
            top_p=0.95,
            typical_p=0.95,
        )

        prompt = PromptTemplate.from_template(
            template="Generate a movie plot based on the below user query.\nBe creative but stay true to the "
                     "description provided.\nUser Query:{context}",
        )

        llm_answer = hf_llm.invoke(prompt.format(context=query))

        # The first line of the completion is dropped (presumably the model's continuation of the
        # "User Query:" line - TODO confirm). A completion with nothing after the first line is
        # returned whole instead of failing or producing an empty plot.
        _, _, remainder = llm_answer.partition("\n")
        return remainder if remainder.strip() else llm_answer

    def retrieve_recommendations(self, query, genre, vibe, initial_top_k=50, final_top_k=10) -> pd.DataFrame:
        """
        Method to retrieve the recommendation from the vector database
        :param query: User query
        :param genre: Selected genre; "All" (or an empty selection) disables the genre filter
        :param vibe: Selected vibe; labels without a score column leave the order unchanged
        :param initial_top_k: Initial number of searched and selected movies
        :param final_top_k: Final number of recommended movies

        :return movies_recs: Final Dataframe of recommended movies
        """
        recs = self.mongodb_vector_store.similarity_search(query, k=initial_top_k)

        # The first whitespace-separated token of each hit is used as the movie uuid.
        # Record the position of every uuid in the search result (first hit wins).
        rank_by_uuid = {}
        for rec in recs:
            tokens = rec.page_content.strip('"').split()
            if tokens:
                rank_by_uuid.setdefault(tokens[0], len(rank_by_uuid))

        # isin() yields rows in DataFrame order; restore the search order so that the
        # truncation below keeps the best matches rather than arbitrary ones.
        matches = self.movies[self.movies["uuid"].isin(list(rank_by_uuid))]
        movies_recs = matches.sort_values(
            by="uuid",
            key=lambda uuids: uuids.map(rank_by_uuid),
            kind="stable",
        ).head(initial_top_k)

        if genre and genre != "All":
            # The dropdown shows capitalized labels, so compare case-insensitively.
            same_genre = movies_recs["genre"].str.casefold() == str(genre).casefold()
            movies_recs = movies_recs[same_genre]
        movies_recs = movies_recs.head(final_top_k)

        sort_column = self._VIBE_SORT_COLUMNS.get(vibe)
        if sort_column is not None:
            # Not in place: movies_recs is a filtered slice of self.movies.
            movies_recs = movies_recs.sort_values(by=sort_column, ascending=False)

        return movies_recs

    @staticmethod
    def _format_directors(director: str) -> str:
        """
        Turn a comma-separated director string into a readable enumeration ("A, B and C")
        """
        names = [name.strip() for name in director.split(",") if name.strip()]
        if not names:
            return director
        if len(names) == 1:
            return names[0]
        return f"{', '.join(names[:-1])} and {names[-1]}"

    @classmethod
    def _truncate_plot(cls, plot: str) -> str:
        """
        Keep the first words of a plot and append "..." only when words were cut off
        """
        words = plot.split()
        preview = " ".join(words[:cls._PLOT_PREVIEW_WORDS])
        return f"{preview}..." if len(words) > cls._PLOT_PREVIEW_WORDS else preview

    def recommend_movies(self, query: str, genre: str, vibe: str) -> str:
        """
        Method to generate a string with the list of selected movies recommended
        :param query: User query
        :param genre: Selected genre
        :param vibe: Selected vibe

        :return output: String with the numbered list of recommended movies
        :raises gr.Error: When the query is empty or contains only whitespace
        """
        if not query or not query.strip():
            raise gr.Error("Enter a description to get recommendations !", duration=5)

        recommendations = self.retrieve_recommendations(query, genre, vibe)

        # The "director" and "plot" values are expected to be strings.
        results = [
            f"{position}. {row['title']} by {self._format_directors(row['director'])}: "
            f"{self._truncate_plot(row['plot'])}"
            for position, (_, row) in enumerate(recommendations.iterrows(), start=1)
        ]

        if not results:
            return "Sorry, our database movies does not have recommendations for the chosen Genre and Vibe :("

        return "\n\n\n".join(results)

    def generate_dashboard(self):
        """
        Build the two-tab Gradio interface, wire the buttons to their handlers and launch it
        """
        theme = gr.themes.Citrus()
        with gr.Blocks(theme=theme) as dashboard:
            gr.Markdown("# Get Movies Recommendations or Generate Your Own Movie Script !!!")

            # Tab 1: recommendations from the vector store
            with gr.Tab(label="Movies Recommender"):
                gr.Markdown("# Movies Recommender")

                with gr.Row():
                    with gr.Column():
                        genre_dropdown = gr.Dropdown(choices=self.genres, label="Select A Genre", value="All")
                    with gr.Column():
                        vibe_dropdown = gr.Dropdown(choices=self.vibe, label="Choose Your Vibe", value="Neutral")

                with gr.Row():
                    user_query = gr.Textbox(
                        label="Please enter a description of the movie you would like to watch:",
                        placeholder="e.g. A story about love in war",
                    )

                with gr.Row():
                    submit_button = gr.Button("Recommend")

                gr.Markdown("## Recommendations")

                with gr.Row():
                    recommendations_output = gr.TextArea(
                        interactive=False,
                        label="Your recommendations will be displayed below:",
                        autoscroll=False,
                        show_label=True,
                        show_copy_button=True,
                    )

                submit_button.click(
                    fn=self.recommend_movies,
                    inputs=[user_query, genre_dropdown, vibe_dropdown],
                    outputs=[recommendations_output],
                )

            # Tab 2: plot generation with the LLM
            with gr.Tab("Movie Script Generator"):
                gr.Markdown("# Movie Script Generator")

                with gr.Row():
                    script_gen_query_textbox = gr.Textbox(
                        label="Enter your prompt here:",
                        lines=1,
                        placeholder="e.g. Generate a movie where a couple "
                                    "discovers love during a war",
                    )

                with gr.Row():
                    generate_button = gr.Button("Generate")

                with gr.Column():
                    script_output = gr.TextArea(
                        interactive=False,
                        placeholder="Your Movie Plot will be displayed here. "
                                    "Don't forget to invite us to your movie premier! :)",
                        autoscroll=False,
                        show_label=False,
                    )

                generate_button.click(
                    fn=self.query_data,
                    inputs=[script_gen_query_textbox],
                    outputs=[script_output],
                )

        dashboard.launch(debug=True)


# Script entry point: constructing GradioDashboard builds the UI and launches it,
# because __init__ calls generate_dashboard(), which ends with dashboard.launch().
if __name__ == "__main__":
    GradioDashboard()