zoya-hammad committed on
Commit
86023ac
·
1 Parent(s): a97c5ac

Add application file

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# imports

import os
import shutil
from random import randint

from dotenv import load_dotenv
import gradio as gr

# imports for langchain, plotly and Chroma

from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
import numpy as np
from sklearn.manifold import TSNE
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt

# OpenAI chat model used for answering questions over the knowledge base.
MODEL = "gpt-4o-mini"

# Load environment variables in a file called .env
load_dotenv(override=True)

# Only copy the key into os.environ when it is actually set: assigning None
# to an os.environ entry raises TypeError and would crash the app at startup.
_api_key = os.getenv('OPENAI_API_KEY')
if _api_key:
    os.environ['OPENAI_API_KEY'] = _api_key

# Directory where uploaded markdown files are copied before indexing.
folder = "my-knowledge-base/"
# Persistence directory for the Chroma vector store.
# NOTE(review): the original assigned db_name twice ("vector_db" then
# "vectorstore_db"); only the final value was ever used, so the dead
# first assignment is dropped.
db_name = "vectorstore_db"
36
def process_files(files):
    """Copy uploaded markdown files into the knowledge base, (re)build the
    Chroma vector store, and create a fresh conversational retrieval chain.

    Args:
        files: list of file paths supplied by the gradio Files component.

    Returns:
        A pair (result, processed_text): result is the raw Chroma collection
        dump ('embeddings', 'documents', 'metadatas') consumed by the
        visualisation callbacks, and processed_text is a markdown summary of
        the files that were processed.
    """
    os.makedirs(folder, exist_ok=True)

    # Copy each upload into the knowledge-base folder, remembering the bare
    # names for the progress summary shown in the UI.
    processed_files = []
    for file in files:
        basename = os.path.basename(file)
        shutil.copy(file, os.path.join(folder, basename))
        processed_files.append(basename)

    # Load every markdown document via LangChain's DirectoryLoader.
    text_loader_kwargs = {'autodetect_encoding': True}
    loader = DirectoryLoader(folder, glob="**/*.md", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)
    documents = loader.load()

    # Attach the filename (without extension) as metadata; the scatter plots
    # use it to colour points by source document.
    for doc in documents:
        filename_md = os.path.basename(doc.metadata["source"])
        filename, _ = os.path.splitext(filename_md)
        doc.metadata["filename"] = filename

    # Split documents into overlapping chunks for embedding.
    text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)

    # Local sentence-transformer embeddings (no API cost).
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Delete any previous vector store so re-processing starts clean.
    if os.path.exists(db_name):
        Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()

    # Store the chunks in ChromaDB.
    vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)

    # Dump the collection once up front; the visualisations reuse this dump.
    collection = vectorstore._collection
    result = collection.get(include=['embeddings', 'documents', 'metadatas'])

    # Build the retrieval chain used by chat(); kept as a module global so
    # the chat callback can reach it.
    llm = ChatOpenAI(temperature=0.7, model_name=MODEL)
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 35})
    global conversation_chain
    conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)

    processed_text = "**Processed Files:**\n\n" + "\n".join(f"- {file}" for file in processed_files)
    return result, processed_text
85
+ def random_color():
86
+ return f"rgb({randint(0,255)},{randint(0,255)},{randint(0,255)})"
87
+
88
def show_embeddings_2d(result):
    """Render a 2D t-SNE projection of the stored embeddings with plotly.

    Args:
        result: Chroma collection dump with 'embeddings', 'documents' and
            'metadatas' keys, as returned by process_files().

    Returns:
        A plotly Figure with one marker per chunk, coloured by source file.
    """
    vectors = np.array(result['embeddings'])
    documents = result['documents']
    metadatas = result['metadatas']
    filenames = [metadata['filename'] for metadata in metadatas]
    filenames_unique = sorted(set(filenames))

    # One random colour per source file, applied per chunk.
    color_map = {name: random_color() for name in filenames_unique}
    colors = [color_map[name] for name in filenames]

    # t-SNE requires perplexity < n_samples; the original hard-coded
    # perplexity=4, which raises for knowledge bases with <= 4 chunks.
    # Clamp it so small uploads still visualise instead of crashing.
    perplexity = min(4, max(1, len(vectors) - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_vectors = tsne.fit_transform(vectors)

    # Create the 2D scatter plot; hover text shows source + a text preview.
    fig = go.Figure(data=[go.Scatter(
        x=reduced_vectors[:, 0],
        y=reduced_vectors[:, 1],
        mode='markers',
        marker=dict(size=5, color=colors, opacity=0.8),
        text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(filenames, documents)],
        hoverinfo='text'
    )])

    fig.update_layout(
        title='2D Chroma Vector Store Visualization',
        scene=dict(xaxis_title='x', yaxis_title='y'),
        width=800,
        height=600,
        margin=dict(r=20, b=10, l=10, t=40)
    )

    return fig
122
def show_embeddings_3d(result):
    """Render a 3D t-SNE projection of the stored embeddings with plotly.

    Args:
        result: Chroma collection dump with 'embeddings', 'documents' and
            'metadatas' keys, as returned by process_files().

    Returns:
        A plotly Figure with one 3D marker per chunk, coloured by source file.
    """
    vectors = np.array(result['embeddings'])
    documents = result['documents']
    metadatas = result['metadatas']
    filenames = [metadata['filename'] for metadata in metadatas]
    filenames_unique = sorted(set(filenames))

    # One random colour per source file, applied per chunk.
    color_map = {name: random_color() for name in filenames_unique}
    colors = [color_map[name] for name in filenames]

    # t-SNE requires perplexity < n_samples; the original relied on the
    # sklearn default (30), which raises for stores with < 31 chunks.
    # Clamp it so small knowledge bases still visualise.
    perplexity = min(30, max(1, len(vectors) - 1))
    tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity)
    reduced_vectors = tsne.fit_transform(vectors)

    fig = go.Figure(data=[go.Scatter3d(
        x=reduced_vectors[:, 0],
        y=reduced_vectors[:, 1],
        z=reduced_vectors[:, 2],
        mode='markers',
        marker=dict(size=5, color=colors, opacity=0.8),
        text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(filenames, documents)],
        hoverinfo='text'
    )])

    fig.update_layout(
        title='3D Chroma Vector Store Visualization',
        scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),
        width=900,
        height=700,
        margin=dict(r=20, b=10, l=10, t=40)
    )

    return fig
156
def chat(question, history=None):
    """Answer a question using the conversational retrieval chain.

    Args:
        question: the user's question text.
        history: chat history from the UI; unused, because the chain keeps
            its own ConversationBufferMemory. Defaults to None — the UI
            wires this callback with only the question input
            (question_btn.click(chat, inputs=[question], ...)), so the
            original two-positional-argument signature raised TypeError on
            every click.

    Returns:
        The chain's answer string, or a prompt to process files first when
        no chain exists yet (previously this was a NameError crash).
    """
    # conversation_chain only exists after process_files() has run.
    if 'conversation_chain' not in globals():
        return "Please upload and process your files first."
    result = conversation_chain.invoke({"question": question})
    return result["answer"]
160
def visualise_data(result):
    """Build the 2D and 3D embedding plots for the given collection dump.

    Args:
        result: Chroma collection dump produced by process_files().

    Returns:
        A (fig_2d, fig_3d) pair of plotly figures.
    """
    return show_embeddings_2d(result), show_embeddings_3d(result)
165
# Shared styling: give every action button the same blue background.
css = """
.btn {background-color: #1d53d1;}
"""

# Assemble the gradio UI: upload/process, chat, and embedding visualisations.
with gr.Blocks(css=css) as ui:
    gr.Markdown("# Markdown-Based Q&A with Visualization")

    # File upload next to the processing progress report.
    with gr.Row():
        file_input = gr.Files(file_types=[".md"], label="Upload Markdown Files")
        with gr.Column(scale=1):
            processed_output = gr.Markdown("Progress")
    with gr.Row():
        process_btn = gr.Button("Process Files", elem_classes=["btn"])

    # Question box and markdown answer area.
    with gr.Row():
        question = gr.Textbox(label="Chat ", lines=10)
        answer = gr.Markdown(label="Response")
    with gr.Row():
        question_btn = gr.Button("Ask a Question", elem_classes=["btn"])
        clear_btn = gr.Button("Clear Output", elem_classes=["btn"])

    # Side-by-side 2D and 3D embedding plots.
    with gr.Row():
        plot_2d = gr.Plot(label="2D Visualization")
        plot_3d = gr.Plot(label="3D Visualization")
    with gr.Row():
        visualise_btn = gr.Button("Visualise Data", elem_classes=["btn"])

    # Collection dump shared between the processing and visualisation steps.
    result = gr.State([])

    # Wire each button to its callback.
    clear_btn.click(fn=lambda: ("", ""), inputs=[], outputs=[question, answer])
    process_btn.click(process_files, inputs=[file_input], outputs=[result, processed_output])
    question_btn.click(chat, inputs=[question], outputs=[answer])
    visualise_btn.click(visualise_data, inputs=[result], outputs=[plot_2d, plot_3d])

# Launch Gradio app
ui.launch(inbrowser=True)