chirag0107 committed on
Commit
509a17c
·
verified ·
1 Parent(s): bad0d8d

Update langchain_movie_search.py

Browse files
Files changed (1) hide show
  1. langchain_movie_search.py +37 -49
langchain_movie_search.py CHANGED
@@ -1,14 +1,18 @@
1
  import os
2
  from typing import List
 
 
3
  from dotenv import load_dotenv
4
  import pymongo
5
  from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
6
  from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
7
  from langchain.chains import create_retrieval_chain
8
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
9
  from langchain_core.prompts import PromptTemplate
10
  import gradio as gr
11
  from gradio.themes.base import Base
 
12
 
13
  __author__ = "Chirag Kamble"
14
 
@@ -24,44 +28,28 @@ class MoviesSearch:
24
  """
25
  # Load environment variables
26
  load_dotenv()
27
- transformer_model_name: str = os.getenv("TRANSFORMER_MODEL_NAME")
28
  mongodb_connection_url: str = os.getenv("MONGODB_CONNECTION_URL")
29
  mongodb_db_name: str = os.getenv("MONGODB_DB_NAME")
30
  mongodb_collection_name: str = os.getenv("MONGODB_COLLECTION_NAME")
31
- self.huggingface_repo: str = os.getenv("HF_REPO")
32
  self.huggingface_api_token: str = os.getenv("HF_TOKEN")
33
  self.huggingface_text_generation_model: str = os.getenv("HUGGINGFACE_TEXT_GENERATION_MODEL")
34
 
35
  # Setup MongoDB connection
36
- self.client: pymongo.synchronous.mongo_client.MongoClient = pymongo.MongoClient(mongodb_connection_url,
37
- serverSelectionTimeoutMS=60000,
38
- tls=True,
39
- connect=False,
40
- tlsAllowInvalidCertificates=True,
41
- directConnection=False,
42
- maxPoolSize=100,
43
- maxIdleTimeMS=60000,
44
- waitQueueTimeoutMS=60000,
45
- connectTimeoutMS=60000,
46
- retryWrites=True,
47
- retryReads=True,
48
- )
49
  db: str = mongodb_db_name
50
  collection_name: str = mongodb_collection_name
51
  self.langchain_movies_collection: pymongo.synchronous.collection.Collection = self.client[db][collection_name]
52
 
53
  self.sample_movies_collection: pymongo.synchronous.collection.Collection = self.client.sample_mflix.movies
54
 
55
- self.hf_plot_embedding = HuggingFaceEmbeddings(
56
- model_name=transformer_model_name,
57
- show_progress=True,
58
- )
59
 
60
  self.retrieve_vector_store = MongoDBAtlasVectorSearch(collection=self.langchain_movies_collection,
61
  embedding=self.hf_plot_embedding,
62
  embedding_key="embedding",
63
- index_name="langchain_movies_vector_index",
64
- text_key="text",
65
  )
66
 
67
  def generate_insert_embeddings(self):
@@ -88,48 +76,48 @@ class MoviesSearch:
88
  hf_llm: HuggingFaceEndpoint = HuggingFaceEndpoint(
89
  repo_id=self.huggingface_text_generation_model,
90
  huggingfacehub_api_token=self.huggingface_api_token,
91
- # temperature=0.1,
92
  task="text-generation",
93
- # max_new_tokens=100,
94
- verbose=True,
95
- return_full_text=True,
 
96
  )
97
 
98
- retriever = self.retrieve_vector_store.as_retriever()
99
-
100
- prompt = PromptTemplate.from_template(template="{context}", template_format="f-string")
101
- combine_docs = create_stuff_documents_chain(llm=hf_llm, prompt=prompt, )
102
-
103
- retrival_chain = create_retrieval_chain(retriever=retriever, combine_docs_chain=combine_docs)
104
- hf_llm_retriever_output = retrival_chain.invoke({"input": query})
105
 
106
- llm_answer = hf_llm_retriever_output.get("answer")
 
 
 
107
 
108
  return llm_answer
109
 
110
  def run_website(self):
111
- with gr.Blocks(theme=Base(), title="Movie plot search App using Vector Search + RAG") as v_search:
112
- gr.Markdown("Movie plot search App using Vector Search + RAG")
113
- textbox = gr.Textbox(label="Enter your question:", lines=1)
 
 
 
114
  with gr.Row():
115
- button = gr.Button("Submit", variant="primary")
116
  with gr.Column():
117
- output = gr.Textbox(lines=1, autoscroll=False, interactive=False,
118
- label="""Output generated by chaining Atlas Vector Search with Langchain's RAG""",)
 
 
 
 
119
 
120
  button.click(fn=self.query_data, inputs=textbox, outputs=[output])
121
 
122
- v_search.launch(share=True)
123
 
124
- def close_client(self):
125
- self.client.close()
126
 
127
-
128
- def gradio_interface(cmd=None):
129
  movie_search = MoviesSearch()
130
- # movie_search.generate_insert_embeddings()
131
  movie_search.run_website()
132
-
133
-
134
- if __name__ == "__main__":
135
- gradio_interface()
 
1
  import os
2
  from typing import List
3
+ import argparse
4
+ import certifi
5
  from dotenv import load_dotenv
6
  import pymongo
7
  from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
8
  from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
9
  from langchain.chains import create_retrieval_chain
10
  from langchain.chains.combine_documents import create_stuff_documents_chain
11
+ from langchain_core.documents import Document
12
  from langchain_core.prompts import PromptTemplate
13
  import gradio as gr
14
  from gradio.themes.base import Base
15
+ from flask import Flask
16
 
17
  __author__ = "Chirag Kamble"
18
 
 
28
  """
29
  # Load environment variables
30
  load_dotenv()
31
+
32
  mongodb_connection_url: str = os.getenv("MONGODB_CONNECTION_URL")
33
  mongodb_db_name: str = os.getenv("MONGODB_DB_NAME")
34
  mongodb_collection_name: str = os.getenv("MONGODB_COLLECTION_NAME")
 
35
  self.huggingface_api_token: str = os.getenv("HF_TOKEN")
36
  self.huggingface_text_generation_model: str = os.getenv("HUGGINGFACE_TEXT_GENERATION_MODEL")
37
 
38
  # Setup MongoDB connection
39
+ self.client: pymongo.synchronous.mongo_client.MongoClient = pymongo.MongoClient(mongodb_connection_url)
 
 
 
 
 
 
 
 
 
 
 
 
40
  db: str = mongodb_db_name
41
  collection_name: str = mongodb_collection_name
42
  self.langchain_movies_collection: pymongo.synchronous.collection.Collection = self.client[db][collection_name]
43
 
44
  self.sample_movies_collection: pymongo.synchronous.collection.Collection = self.client.sample_mflix.movies
45
 
46
+ self.hf_plot_embedding = HuggingFaceEmbeddings()
 
 
 
47
 
48
  self.retrieve_vector_store = MongoDBAtlasVectorSearch(collection=self.langchain_movies_collection,
49
  embedding=self.hf_plot_embedding,
50
  embedding_key="embedding",
51
+ index_name="movies_data_12k_vector_index",
52
+ text_key="uuid_plot",
53
  )
54
 
55
  def generate_insert_embeddings(self):
 
76
  hf_llm: HuggingFaceEndpoint = HuggingFaceEndpoint(
77
  repo_id=self.huggingface_text_generation_model,
78
  huggingfacehub_api_token=self.huggingface_api_token,
79
+ temperature=0.1,
80
  task="text-generation",
81
+ repetition_penalty=1.03,
82
+ top_k=10,
83
+ top_p=0.95,
84
+ typical_p=0.95,
85
  )
86
 
87
+ prompt = PromptTemplate.from_template(
88
+ template="Generate a movie plot based on the below description.\nBe creative but stay true to the "
89
+ "description provided.\nDescription:{context}",
90
+ )
 
 
 
91
 
92
+ formatted_prompt = prompt.format(context=query)
93
+ llm_answer = hf_llm.invoke(formatted_prompt)
94
+ llm_answer = llm_answer.split("\n", 1)[1]
95
+ print(llm_answer)
96
 
97
  return llm_answer
98
 
99
  def run_website(self):
100
+ theme = gr.themes.Ocean()
101
+ with gr.Blocks(theme=theme, title="Movie Plot Generation using Vector Search + RAG") as dashboard:
102
+ gr.Markdown("# Generate Movie Plot using Vector Search + RAG")
103
+ with gr.Row():
104
+ textbox = gr.Textbox(label="Enter your prompt here:", lines=1,
105
+ placeholder="e.g. Generate a movie of a couple discovering love in war")
106
  with gr.Row():
107
+ button = gr.Button("Generate")
108
  with gr.Column():
109
+ output = gr.Textbox(interactive=False,
110
+ label="Here is a Movie Plot for you. Don't forget to invite us to the premier!",
111
+ autoscroll=False,
112
+ show_label=True,
113
+ show_copy_button=True,
114
+ )
115
 
116
  button.click(fn=self.query_data, inputs=textbox, outputs=[output])
117
 
118
+ dashboard.launch(debug=True)
119
 
 
 
120
 
121
+ if __name__ == "__main__":
 
122
  movie_search = MoviesSearch()
 
123
  movie_search.run_website()