HaryaniAnjali commited on
Commit
c4c77a9
·
verified ·
1 Parent(s): 6b5d8c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -207
app.py CHANGED
@@ -1,227 +1,137 @@
1
- import os
2
- import logging
3
  import gradio as gr
 
4
  from langchain.chains import ConversationalRetrievalChain
5
  from langchain_openai import ChatOpenAI
6
- from langchain.memory import ConversationBufferMemory
7
- from langchain_community.vectorstores import FAISS
8
- from langchain_openai import OpenAIEmbeddings
9
  from langchain_community.document_loaders import WikipediaLoader
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
11
  from langchain.callbacks.base import BaseCallbackHandler
12
 
13
- # Setup logging
 
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
- # Memory cache for storing answers
18
- class MemoryCache:
19
- def __init__(self):
20
- self.cache = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- def get(self, query: str):
23
- if query in self.cache:
24
- logger.info(f"Cache hit: {query}")
25
- return self.cache.get(query)
26
- return None
27
 
28
- def set(self, query: str, response: str):
29
- logger.info(f"Saving to cache: {query}")
30
- self.cache[query] = response
31
 
32
- # Callback handler for logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class LoggingCallbackHandler(BaseCallbackHandler):
34
  def on_chain_start(self, serialized, inputs, **kwargs):
35
- logger.info(f"Chain start. Inputs: {inputs}")
36
 
37
  def on_chain_end(self, outputs, **kwargs):
38
- logger.info(f"Chain end. Outputs: {outputs}")
39
-
40
- def on_retriever_start(self, *args, **kwargs):
41
- logger.info("Retrieval start.")
42
-
43
- def on_retriever_end(self, *args, **kwargs):
44
- logger.info("Retrieval end.")
45
-
46
- def on_llm_start(self, *args, **kwargs):
47
- logger.info("LLM start.")
48
-
49
- def on_llm_end(self, result, *args, **kwargs):
50
- try:
51
- final_text = result.generations[0][0].text
52
- logger.info(f"LLM end. Text: {final_text}")
53
- except Exception as e:
54
- logger.error(f"LLM error: {e}")
55
-
56
- class GenAIQASystem:
57
- def __init__(self):
58
- self.cache = MemoryCache()
59
- self.callback_handler = LoggingCallbackHandler()
60
- self.content = None
61
- self.qa_chain = None
62
- self.memory = None
63
- self.wiki_loaded = False
64
- self.api_key_set = False
65
-
66
- def set_api_key(self, api_key):
67
- if not api_key:
68
- return "Please provide a valid API key."
69
-
70
- try:
71
- os.environ["OPENAI_API_KEY"] = api_key
72
- # Test if API key works
73
- embeddings = OpenAIEmbeddings()
74
- embeddings.embed_query("Test")
75
- self.api_key_set = True
76
- return "API key set successfully!"
77
- except Exception as e:
78
- logger.error(f"API key error: {e}")
79
- return f"Error setting API key: {str(e)}"
80
-
81
- def load_wikipedia(self):
82
- if not self.api_key_set:
83
- return "Please set your OpenAI API key first."
84
-
85
- if self.wiki_loaded:
86
- return "Wikipedia content already loaded."
87
-
88
- try:
89
- logger.info("Loading Wikipedia content for Generative artificial intelligence")
90
-
91
- # Load Wikipedia content
92
- loader = WikipediaLoader(query="Generative artificial intelligence", lang="en")
93
- documents = loader.load()
94
- self.content = documents[0].page_content
95
-
96
- # Split content into chunks
97
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
98
- chunks = text_splitter.split_text(self.content)
99
-
100
- # Create vector store
101
- embeddings = OpenAIEmbeddings()
102
- vectorstore = FAISS.from_texts(chunks, embeddings)
103
-
104
- # Initialize memory
105
- self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
106
-
107
- # Create QA Chain
108
- llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
109
- self.qa_chain = ConversationalRetrievalChain.from_llm(
110
- llm,
111
- retriever=vectorstore.as_retriever(),
112
- memory=self.memory,
113
- callbacks=[self.callback_handler]
114
- )
115
-
116
- self.wiki_loaded = True
117
- return "Wikipedia content loaded successfully!"
118
- except Exception as e:
119
- logger.error(f"Error loading Wikipedia: {e}")
120
- return f"Error loading Wikipedia: {str(e)}"
121
-
122
- def extract_section(self, query: str):
123
- """Extracts a specific section from the Wikipedia content."""
124
- if not self.content:
125
- return None
126
-
127
- query_lower = query.lower()
128
- content_lower = self.content.lower()
129
-
130
- # Dictionary of section headers to look for
131
- sections = {
132
- "early history": "== early history ==",
133
- "generative models": "== generative models ==",
134
- "academic artificial intelligence": "== academic artificial intelligence =="
135
- }
136
-
137
- # Check if query matches any section
138
- for key, header in sections.items():
139
- if key in query_lower:
140
- start_index = content_lower.find(header)
141
- if start_index != -1:
142
- logger.info(f"Found header: {header}")
143
- end_index = self.content.find("\n==", start_index + len(header))
144
- section_text = self.content[start_index:end_index].strip() if end_index != -1 else self.content[start_index:].strip()
145
- return section_text
146
-
147
- return None
148
-
149
- def process_query(self, query):
150
- if not self.api_key_set:
151
- return "Please set your OpenAI API key in the Settings tab first."
152
-
153
- if not self.wiki_loaded:
154
- return "Please load Wikipedia content in the Settings tab first."
155
-
156
- # Check cache first
157
- cached_answer = self.cache.get(query)
158
- if cached_answer:
159
- return cached_answer
160
-
161
- # Try to extract a specific section
162
- extracted_section = self.extract_section(query)
163
- if extracted_section:
164
- self.cache.set(query, extracted_section)
165
- return f"[Section Found] {extracted_section}"
166
-
167
- # Use the QA chain
168
- try:
169
- logger.info(f"Processing query: {query}")
170
- result = self.qa_chain.invoke({"question": query})
171
- answer = result.get("answer", "No answer found")
172
- self.cache.set(query, answer)
173
- return answer
174
- except Exception as e:
175
- logger.error(f"Error in QA chain: {e}")
176
- return f"Error processing query: {str(e)}"
177
-
178
- # Initialize system
179
- qa_system = GenAIQASystem()
180
-
181
- # Define Gradio interface
182
- with gr.Blocks(title="Generative AI Q/A System") as demo:
183
- gr.Markdown("# Generative AI Q/A System")
184
- gr.Markdown("Ask questions about Generative AI using this LangChain-based Q/A system")
185
-
186
- with gr.Tab("Chat"):
187
- chatbot = gr.Chatbot()
188
- msg = gr.Textbox(label="Your Question")
189
- clear = gr.Button("Clear")
190
-
191
- def respond(message, history):
192
- response = qa_system.process_query(message)
193
- history.append((message, response))
194
- return "", history
195
-
196
- msg.submit(respond, [msg, chatbot], [msg, chatbot])
197
- clear.click(lambda: [], None, chatbot, queue=False)
198
-
199
- with gr.Tab("Settings"):
200
- with gr.Group():
201
- gr.Markdown("### Step 1: Set OpenAI API Key")
202
- api_key_input = gr.Textbox(type="password", label="OpenAI API Key")
203
- api_submit = gr.Button("Set API Key")
204
- api_status = gr.Textbox(label="API Status", interactive=False)
205
-
206
- with gr.Group():
207
- gr.Markdown("### Step 2: Load Wikipedia Content")
208
- load_wiki_button = gr.Button("Load Wikipedia Content")
209
- wiki_status = gr.Textbox(label="Loading Status", interactive=False)
210
-
211
- api_submit.click(qa_system.set_api_key, [api_key_input], [api_status])
212
- load_wiki_button.click(qa_system.load_wikipedia, [], [wiki_status])
213
-
214
- gr.Markdown("## About")
215
- gr.Markdown("""
216
- This Q/A system uses LangChain and OpenAI to answer questions based on the Wikipedia page about Generative AI.
217
-
218
- Features:
219
- - Caching mechanism to avoid repeating work
220
- - Function calls to extract specific sections
221
- - Logging to track processing
222
-
223
- Created by Anjali Haryani
224
- """)
225
 
226
  if __name__ == "__main__":
227
  demo.launch()
 
 
 
1
  import gradio as gr
2
+ import logging
3
  from langchain.chains import ConversationalRetrievalChain
4
  from langchain_openai import ChatOpenAI
5
+ from langchain.memory import ConversationBufferMemory # Using the updated memory package
6
+ from langchain_community.vectorstores import Chroma # Corrected import for Chroma
7
+ from langchain_openai import OpenAIEmbeddings # Updated import for OpenAIEmbeddings
8
  from langchain_community.document_loaders import WikipediaLoader
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain.tools import StructuredTool
11
  from langchain.callbacks.base import BaseCallbackHandler
12
 
13
# ================================
# Step 1: Setup Logging for Debugging
# ================================
# Module-level logger; INFO level so the chain start/end callbacks below
# are actually emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
18
 
19
# ================================
# Step 2: Load Wikipedia Data
# ================================
def fetch_wikipedia_content():
    """Return the article text for 'Generative artificial intelligence'.

    Uses LangChain's WikipediaLoader; falls back to the literal string
    "Page not found." when the loader yields no documents, so callers
    always receive a str.
    """
    docs = WikipediaLoader(query="Generative artificial intelligence", lang="en").load()
    if not docs:
        return "Page not found."
    return docs[0].page_content

# Fetched once at import time (network call happens on module load).
wiki_text = fetch_wikipedia_content()
29
+
30
# ================================
# Step 3: Process Wikipedia Text for Retrieval
# ================================
def process_and_store_wikipedia(text, persist_directory="/home/user/chroma_db"):
    """Split *text* into chunks, embed them, and return a retriever.

    Parameters
    ----------
    text : str
        Raw Wikipedia article text.
    persist_directory : str, optional
        Where Chroma persists its index. The default preserves the
        previously hard-coded path, so existing callers are unaffected.

    Returns
    -------
    A retriever over the embedded chunks.
    """
    # 500-char chunks with 50-char overlap keep sentences coherent across splits.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(text)

    embeddings = OpenAIEmbeddings()  # requires OPENAI_API_KEY in the environment
    vectorstore = Chroma.from_texts(chunks, embedding=embeddings, persist_directory=persist_directory)
    return vectorstore.as_retriever()

# Built once at import time; embedding every chunk calls the OpenAI API.
retriever = process_and_store_wikipedia(wiki_text)
43
+
44
# ================================
# Step 4: Initialize Chat Model and Memory
# ================================
# NOTE(review): model uses its default temperature; the Temperature/Top-p
# sliders in the UI below are never forwarded here — confirm intentional.
llm = ChatOpenAI(model_name="gpt-4o")
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)  # Initialize memory for conversation history

# ================================
# Step 5: Create Q/A Retrieval Chain
# ================================
# The chain retrieves relevant chunks per question and records each
# (question, answer) turn into `memory` automatically.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm, retriever=retriever, memory=memory
)
56
+
57
# ================================
# Step 6: Implement Chatbot Response Function with Caching
# ================================
def ask_with_memory(query):
    """Answer *query*, reusing a stored answer when the identical question
    was already asked this session.

    ConversationBufferMemory (return_messages=True) stores turns as an
    alternating [Human, AI, Human, AI, ...] message list; the cache scan
    matches only the Human positions and returns the AI message that
    immediately follows.
    """
    chat_history = memory.load_memory_variables({})["chat_history"]

    # Step by 2 so only human questions (even indices) are compared —
    # the original scanned every message, so an AI answer whose text
    # happened to equal the query could be mistaken for a question.
    for i in range(0, len(chat_history) - 1, 2):
        if chat_history[i].content == query:
            return chat_history[i + 1].content  # cached answer

    # Not cached: run the retrieval chain. The chain already persists this
    # turn into `memory` (memory is attached to qa_chain), so calling
    # memory.save_context() here as well would duplicate every turn and
    # break the pairwise Human/AI layout the cache scan relies on.
    return qa_chain.invoke({"question": query})["answer"]
 
 
 
 
78
 
 
 
 
79
 
80
# ================================
# Step 7: Implement Structured Function Calling for Section Extraction
# ================================
def extract_section_by_query(query: str) -> str:
    """Return the Wikipedia chunk most relevant to *query*, found via
    embedding similarity on the shared retriever."""
    docs = retriever.get_relevant_documents(query)
    if not docs:
        return "Section not found."

    top = docs[0]
    # Chunks built via Chroma.from_texts carry no metadata, so 'title'
    # typically falls back to 'Unknown'.
    title = top.metadata.get('title', 'Unknown')
    return f"Section: {title}\n\n{top.page_content}"
94
+
95
# Expose the extractor as a structured tool for function-calling agents.
# NOTE(review): this tool is constructed but never attached to any agent
# or chain in this file — confirm where it is meant to be registered.
section_extraction_tool = StructuredTool.from_function(
    extract_section_by_query,
    name="extract_section_by_query",
    description="Finds the most relevant Wikipedia section based on a user query using embeddings."
)
100
+
101
# ================================
# Step 8: Implement Callback Logging for Debugging
# ================================
class LoggingCallbackHandler(BaseCallbackHandler):
    """Callback that logs chain lifecycle events for debugging."""

    def on_chain_start(self, serialized, inputs, **kwargs):
        # Fires before each chain invocation; logs the raw input mapping.
        logger.info(f"Starting chain execution with input: {inputs}")

    def on_chain_end(self, outputs, **kwargs):
        # Fires after the chain produces its output mapping.
        logger.info(f"Chain execution finished. Output: {outputs}")
110
+
111
callback_handler = LoggingCallbackHandler()
# NOTE(review): assigning `.callbacks` after construction relies on the
# chain honoring the attribute; passing callbacks=[...] to from_llm() is
# the documented route — confirm these log events actually fire.
qa_chain.callbacks = [callback_handler]
113
+
114
# ================================
# Step 9: Define Gradio Interface
# ================================
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """
    Processes user query and retrieves answers from Wikipedia-based Q/A system with caching.

    NOTE(review): `history`, `system_message`, `max_tokens`, `temperature`
    and `top_p` exist only to satisfy gr.ChatInterface's callback signature;
    none are forwarded to the chain, so the UI sliders have no effect.
    """
    return ask_with_memory(message)
122
+
123
# ================================
# Step 10: Create Gradio Interface
# ================================
# Chat UI; the additional inputs are rendered but currently ignored by
# respond() (see note there).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an AI expert answering questions about Generative AI.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  if __name__ == "__main__":
137
  demo.launch()