HaryaniAnjali committed on
Commit
f298aa5
·
verified ·
1 Parent(s): c4c77a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -117
app.py CHANGED
@@ -1,137 +1,227 @@
1
- import gradio as gr
2
  import logging
 
3
  from langchain.chains import ConversationalRetrievalChain
4
  from langchain_openai import ChatOpenAI
5
- from langchain.memory import ConversationBufferMemory # Using the updated memory package
6
- from langchain_community.vectorstores import Chroma # Corrected import for Chroma
7
- from langchain_openai import OpenAIEmbeddings # Updated import for OpenAIEmbeddings
8
  from langchain_community.document_loaders import WikipediaLoader
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
- from langchain.tools import StructuredTool
11
  from langchain.callbacks.base import BaseCallbackHandler
12
 
13
- # ================================
14
- # Step 1: Setup Logging for Debugging
15
- # ================================
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
- # ================================
20
- # Step 2: Load Wikipedia Data
21
- # ================================
22
def fetch_wikipedia_content():
    """Return the text of the Wikipedia article on generative AI.

    Runs a ``WikipediaLoader`` query for "Generative artificial intelligence"
    (English) and returns the ``page_content`` of the first document.

    Returns:
        The article text, or the placeholder string ``"Page not found."``
        when the loader yields no documents.
    """
    pages = WikipediaLoader(
        query="Generative artificial intelligence",
        lang="en",
    ).load()
    if not pages:
        return "Page not found."
    return pages[0].page_content
27
-
28
- wiki_text = fetch_wikipedia_content()
29
-
30
- # ================================
31
- # Step 3: Process Wikipedia Text for Retrieval
32
- # ================================
33
def process_and_store_wikipedia(text):
    """Chunk *text*, embed the chunks, and return a retriever over them.

    The text is split into 500-character chunks with a 50-character overlap,
    embedded with ``OpenAIEmbeddings`` and stored in a Chroma collection
    persisted under ``/home/user/chroma_db``.

    Args:
        text: The raw article text to index.

    Returns:
        The retriever obtained from the Chroma vector store.
    """
    pieces = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
    ).split_text(text)
    store = Chroma.from_texts(
        pieces,
        embedding=OpenAIEmbeddings(),
        persist_directory="/home/user/chroma_db",
    )
    return store.as_retriever()
41
-
42
- retriever = process_and_store_wikipedia(wiki_text)
43
-
44
- # ================================
45
- # Step 4: Initialize Chat Model and Memory
46
- # ================================
47
- llm = ChatOpenAI(model_name="gpt-4o")
48
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) # Initialize memory for conversation history
49
-
50
- # ================================
51
- # Step 5: Create Q/A Retrieval Chain
52
- # ================================
53
- qa_chain = ConversationalRetrievalChain.from_llm(
54
- llm, retriever=retriever, memory=memory
55
- )
56
-
57
- # ================================
58
- # Step 6: Implement Chatbot Response Function with Caching
59
- # ================================
60
def ask_with_memory(query):
    """Answer *query*, reusing the earlier answer when the same question was asked before.

    The conversation history held in the module-level ``memory`` is scanned for
    a previous human turn whose text equals *query*; when one is found, the
    message that followed it is returned without calling the model. Otherwise
    the question is sent through the module-level ``qa_chain``.

    Args:
        query: The user's question.

    Returns:
        The answer text.
    """
    history = memory.load_memory_variables({})["chat_history"]

    # Walk consecutive (question, answer) pairs. Only a human turn can be a
    # cached question: without the role check, an AI reply whose text happens
    # to equal the query would match and the *next human message* would be
    # returned as the "answer".
    for asked, replied in zip(history, history[1:]):
        if getattr(asked, "type", None) == "human" and asked.content == query:
            return replied.content

    # qa_chain is constructed with memory=memory, so invoke() already records
    # this exchange. Saving it again here would store every turn twice and
    # feed duplicated history back into later prompts.
    return qa_chain.invoke({"question": query})["answer"]
 
 
 
 
78
 
 
 
 
79
 
80
- # ================================
81
- # Step 7: Implement Structured Function Calling for Section Extraction
82
- # ================================
83
def extract_section_by_query(query: str) -> str:
    """Return the stored chunk that best matches *query*.

    The module-level ``retriever`` is asked for documents relevant to the
    query and the first one is formatted for display.

    Args:
        query: Free-text description of the section being looked for.

    Returns:
        ``"Section not found."`` when nothing is retrieved; otherwise a string
        made of a ``Section: <title>`` line (``Unknown`` when the document
        metadata has no ``title``), a blank line, and the document text.
    """
    matches = retriever.get_relevant_documents(query)
    if not matches:
        return "Section not found."

    best = matches[0]
    heading = best.metadata.get('title', 'Unknown')
    return f"Section: {heading}\n\n{best.page_content}"
94
-
95
- section_extraction_tool = StructuredTool.from_function(
96
- extract_section_by_query,
97
- name="extract_section_by_query",
98
- description="Finds the most relevant Wikipedia section based on a user query using embeddings."
99
- )
100
-
101
# Step 8: callback handler that logs chain execution for debugging
class LoggingCallbackHandler(BaseCallbackHandler):
    """Callback handler that writes chain start and end events to the module logger."""

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Log the inputs a chain is about to run with."""
        message = f"Starting chain execution with input: {inputs}"
        logger.info(message)

    def on_chain_end(self, outputs, **kwargs):
        """Log the outputs a chain produced."""
        message = f"Chain execution finished. Output: {outputs}"
        logger.info(message)
110
-
111
- callback_handler = LoggingCallbackHandler()
112
- qa_chain.callbacks = [callback_handler]
113
-
114
- # ================================
115
- # Step 9: Define Gradio Interface
116
- # ================================
117
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Produce the chatbot reply for one user message.

    Only *message* is used: it is forwarded to ``ask_with_memory``. The
    remaining parameters are accepted but not read by this function.
    """
    answer = ask_with_memory(message)
    return answer
122
-
123
- # ================================
124
- # Step 10: Create Gradio Interface
125
- # ================================
126
- demo = gr.ChatInterface(
127
- respond,
128
- additional_inputs=[
129
- gr.Textbox(value="You are an AI expert answering questions about Generative AI.", label="System message"),
130
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
131
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
132
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
133
- ],
134
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  if __name__ == "__main__":
137
  demo.launch()
 
1
+ import os
2
  import logging
3
+ import gradio as gr
4
  from langchain.chains import ConversationalRetrievalChain
5
  from langchain_openai import ChatOpenAI
6
+ from langchain.memory import ConversationBufferMemory
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_openai import OpenAIEmbeddings
9
  from langchain_community.document_loaders import WikipediaLoader
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
11
  from langchain.callbacks.base import BaseCallbackHandler
12
 
13
+ # Setup logging
 
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
+ # Memory cache for storing answers
18
class MemoryCache:
    """In-process store mapping a query string to the response produced for it."""

    def __init__(self):
        # query text -> response text
        self.cache = {}

    def get(self, query: str):
        """Return the response stored for *query*, or ``None`` when there is none.

        A hit is reported on the module logger at INFO level.
        """
        if query not in self.cache:
            return None
        logger.info(f"Cache hit: {query}")
        return self.cache[query]

    def set(self, query: str, response: str):
        """Store *response* under *query*, replacing any earlier entry."""
        logger.info(f"Saving to cache: {query}")
        self.cache[query] = response
31
 
32
# Callback handler that logs chain, retriever and LLM lifecycle events
class LoggingCallbackHandler(BaseCallbackHandler):
    """Callback handler that reports each pipeline stage on the module logger."""

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Log the inputs a chain starts with."""
        message = f"Chain start. Inputs: {inputs}"
        logger.info(message)

    def on_chain_end(self, outputs, **kwargs):
        """Log the outputs a chain finished with."""
        message = f"Chain end. Outputs: {outputs}"
        logger.info(message)

    def on_retriever_start(self, *args, **kwargs):
        """Log that document retrieval has begun."""
        logger.info("Retrieval start.")

    def on_retriever_end(self, *args, **kwargs):
        """Log that document retrieval has finished."""
        logger.info("Retrieval end.")

    def on_llm_start(self, *args, **kwargs):
        """Log that an LLM call has begun."""
        logger.info("LLM start.")

    def on_llm_end(self, result, *args, **kwargs):
        """Log the text of the first generation in *result*.

        Any failure while reading the text out of *result* is logged as an
        error instead of being raised.
        """
        try:
            first_generation = result.generations[0][0]
            logger.info(f"LLM end. Text: {first_generation.text}")
        except Exception as e:
            logger.error(f"LLM error: {e}")
55
+
56
class GenAIQASystem:
    """Question-answering pipeline over the Wikipedia article on generative AI.

    Intended call order: ``set_api_key`` -> ``load_wikipedia`` ->
    ``process_query``. Each step returns a human-readable status or answer
    string rather than raising, so the methods can be wired directly to UI
    callbacks.
    """

    # Lower-case titles of the article sections that can be returned verbatim.
    SECTION_TITLES = (
        "early history",
        "generative models",
        "academic artificial intelligence",
    )

    def __init__(self):
        self.cache = MemoryCache()
        self.callback_handler = LoggingCallbackHandler()
        self.content = None  # full article text, set by load_wikipedia
        self.qa_chain = None  # ConversationalRetrievalChain, set by load_wikipedia
        self.memory = None  # conversation memory backing qa_chain
        self.wiki_loaded = False
        self.api_key_set = False

    def set_api_key(self, api_key):
        """Install *api_key* as ``OPENAI_API_KEY`` after checking that it works.

        The key is validated by embedding a short test string. If validation
        fails, the previous value of ``OPENAI_API_KEY`` (or its absence) is
        restored so a rejected key never stays active.

        Args:
            api_key: The OpenAI API key; surrounding whitespace is ignored.

        Returns:
            A status message describing success or the failure reason.
        """
        api_key = (api_key or "").strip()
        if not api_key:
            return "Please provide a valid API key."

        # NOTE(review): the key is stored in the process environment, so it is
        # shared by everything running in this process, not scoped to one
        # caller. Confirm that is acceptable for the deployment.
        previous_key = os.environ.get("OPENAI_API_KEY")
        os.environ["OPENAI_API_KEY"] = api_key
        try:
            # Cheapest round trip that proves the key is accepted.
            OpenAIEmbeddings().embed_query("Test")
        except Exception as e:
            # Roll back so a bad key does not replace a working one.
            if previous_key is None:
                os.environ.pop("OPENAI_API_KEY", None)
            else:
                os.environ["OPENAI_API_KEY"] = previous_key
            logger.error(f"API key error: {e}")
            return f"Error setting API key: {str(e)}"

        self.api_key_set = True
        return "API key set successfully!"

    def load_wikipedia(self):
        """Download the article, index it, and build the retrieval QA chain.

        Does nothing when the API key has not been set or the content is
        already loaded. Instance state (``content``, ``memory``, ``qa_chain``,
        ``wiki_loaded``) is only updated once every step has succeeded.

        Returns:
            A status message describing the outcome.
        """
        if not self.api_key_set:
            return "Please set your OpenAI API key first."

        if self.wiki_loaded:
            return "Wikipedia content already loaded."

        try:
            logger.info("Loading Wikipedia content for Generative artificial intelligence")

            loader = WikipediaLoader(query="Generative artificial intelligence", lang="en")
            documents = loader.load()
            if not documents:
                # Report an empty result clearly instead of surfacing an IndexError.
                logger.error("Error loading Wikipedia: no documents returned")
                return "Error loading Wikipedia: no page was returned for the query."
            content = documents[0].page_content

            # Split the article into overlapping chunks and index them.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_text(content)
            vectorstore = FAISS.from_texts(chunks, OpenAIEmbeddings())

            memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
            llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm,
                retriever=vectorstore.as_retriever(),
                memory=memory,
                callbacks=[self.callback_handler],
            )
        except Exception as e:
            logger.error(f"Error loading Wikipedia: {e}")
            return f"Error loading Wikipedia: {str(e)}"

        # Commit only after everything above succeeded, so a failed load never
        # leaves the instance half-initialised.
        self.content = content
        self.memory = memory
        self.qa_chain = qa_chain
        self.wiki_loaded = True
        return "Wikipedia content loaded successfully!"

    @staticmethod
    def _parse_heading(line):
        """Return ``(level, title)`` for a ``== Title ==`` style line, else ``None``.

        ``level`` is the number of ``=`` characters on each side (2 or more).
        """
        stripped = line.strip()
        level = len(stripped) - len(stripped.lstrip("="))
        if level < 2 or len(stripped) <= 2 * level:
            return None
        if not stripped.endswith("=" * level):
            return None
        title = stripped[level:-level]
        if title.startswith("=") or title.endswith("="):
            # Unbalanced markers such as "== Title ===".
            return None
        return level, title.strip()

    def extract_section(self, query: str):
        """Return the article section named in *query*, if any.

        A section is returned when one of ``SECTION_TITLES`` occurs in the
        lower-cased query and a heading with that title exists in the loaded
        content. The returned text starts at the heading line and runs up to
        the next heading of the same or a higher level, so nested subsections
        are kept.

        Args:
            query: The user's question.

        Returns:
            The section text (heading included, surrounding whitespace
            stripped), or ``None`` when no content is loaded or no section
            matches.
        """
        if not self.content:
            return None

        query_lower = query.lower()
        lines = self.content.splitlines()

        # (line index, level, lower-cased title) for every heading, in order.
        headings = []
        for index, line in enumerate(lines):
            parsed = self._parse_heading(line)
            if parsed is not None:
                headings.append((index, parsed[0], parsed[1].lower()))

        for key in self.SECTION_TITLES:
            if key not in query_lower:
                continue
            for position, (start, level, title) in enumerate(headings):
                if title != key:
                    continue
                logger.info(f"Found header: == {key} ==")
                end = len(lines)
                for next_start, next_level, _ in headings[position + 1:]:
                    # A deeper heading is a subsection and belongs to this one.
                    if next_level <= level:
                        end = next_start
                        break
                return "\n".join(lines[start:end]).strip()

        return None

    def process_query(self, query):
        """Answer *query* from the cache, a verbatim section, or the QA chain.

        Args:
            query: The user's question.

        Returns:
            The answer text, a ``[Section Found]``-prefixed section, or a
            message explaining which setup step is missing or what failed.
        """
        if not self.api_key_set:
            return "Please set your OpenAI API key in the Settings tab first."

        if not self.wiki_loaded:
            return "Please load Wikipedia content in the Settings tab first."

        # An empty string is a legitimate cached answer, so test for None.
        cached_answer = self.cache.get(query)
        if cached_answer is not None:
            return cached_answer

        extracted_section = self.extract_section(query)
        if extracted_section:
            # Cache exactly what is returned so a repeated query yields the
            # same text as the first time.
            response = f"[Section Found] {extracted_section}"
            self.cache.set(query, response)
            return response

        try:
            logger.info(f"Processing query: {query}")
            result = self.qa_chain.invoke({"question": query})
            answer = result.get("answer", "No answer found")
            self.cache.set(query, answer)
            return answer
        except Exception as e:
            logger.error(f"Error in QA chain: {e}")
            return f"Error processing query: {str(e)}"
177
+
178
+ # Initialize system
179
qa_system = GenAIQASystem()

# Gradio interface: a Chat tab for questions and a Settings tab for the
# two setup steps (API key, then Wikipedia load).
with gr.Blocks(title="Generative AI Q/A System") as demo:
    gr.Markdown("# Generative AI Q/A System")
    gr.Markdown("Ask questions about Generative AI using this LangChain-based Q/A system")

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Your Question")
        clear = gr.Button("Clear")

        def respond(message, history):
            """Answer *message* and append the exchange to the chat history.

            Returns a pair: an empty string (clears the input box) and the
            updated history.
            """
            history = history or []
            if not message or not message.strip():
                # Nothing was asked: leave the conversation untouched rather
                # than sending an empty question through the QA pipeline.
                return "", history
            response = qa_system.process_query(message)
            history.append((message, response))
            return "", history

        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        clear.click(lambda: [], None, chatbot, queue=False)

    with gr.Tab("Settings"):
        with gr.Group():
            gr.Markdown("### Step 1: Set OpenAI API Key")
            api_key_input = gr.Textbox(type="password", label="OpenAI API Key")
            api_submit = gr.Button("Set API Key")
            api_status = gr.Textbox(label="API Status", interactive=False)

        with gr.Group():
            gr.Markdown("### Step 2: Load Wikipedia Content")
            load_wiki_button = gr.Button("Load Wikipedia Content")
            wiki_status = gr.Textbox(label="Loading Status", interactive=False)

        api_submit.click(qa_system.set_api_key, [api_key_input], [api_status])
        load_wiki_button.click(qa_system.load_wikipedia, [], [wiki_status])

    gr.Markdown("## About")
    gr.Markdown("""
    This Q/A system uses LangChain and OpenAI to answer questions based on the Wikipedia page about Generative AI.

    Features:
    - Caching mechanism to avoid repeating work
    - Function calls to extract specific sections
    - Logging to track processing

    Created by Anjali Haryani
    """)

if __name__ == "__main__":
    demo.launch()