HaryaniAnjali commited on
Commit
633adde
·
verified ·
1 Parent(s): 5d4f3f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -104
app.py CHANGED
@@ -1,129 +1,184 @@
1
  import os
 
2
  import gradio as gr
3
- from langchain.chat_models import ChatOpenAI
4
- from langchain.document_loaders import WikipediaLoader
5
- from langchain.text_splitter import CharacterTextSplitter
6
- from langchain.embeddings import HuggingFaceEmbeddings
7
- from langchain.vectorstores import FAISS
8
  from langchain.chains import ConversationalRetrievalChain
 
9
  from langchain.memory import ConversationBufferMemory
10
- from langchain.embeddings import OpenAIEmbeddings
 
 
 
 
11
 
12
# Simple in-memory response cache keyed by the raw query string.
class MemoryCache:
    """Tiny dict-backed memo of previously computed responses."""

    def __init__(self):
        # Backing store; lives for the lifetime of the process.
        self.cache = {}

    def get(self, query):
        """Return the stored response for *query*, or None when absent."""
        return self.cache.get(query, None)

    def set(self, query, response):
        """Remember *response* for *query*, replacing any earlier value."""
        self.cache[query] = response
22
 
23
# Function to extract specific sections (simplified placeholder lookup).
def extract_section(query, content):
    """Return canned section text when *query* mentions a known topic, else None.

    *content* is accepted for interface compatibility but is not inspected.
    Topics are checked in a fixed order, mirroring the original if/elif chain.
    """
    placeholders = (
        ("early history",
         "Information about the early history of Generative AI would appear here."),
        ("generative models",
         "Information about generative models would appear here."),
        ("academic artificial intelligence",
         "Information about academic artificial intelligence would appear here."),
    )
    lowered = query.lower()
    for topic, text in placeholders:
        if topic in lowered:
            return text
    return None
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
# Main QA class (simplified)
class GenAIQASystem:
    """Q/A system over the Generative AI Wikipedia page.

    Lifecycle: initialize() installs the OpenAI key, load_wikipedia() builds
    the retrieval chain, process_query() answers questions (cache first, then
    canned section extraction, then the chain).
    """

    def __init__(self):
        self.cache = MemoryCache()  # query -> answer memo
        self.content = None         # page text (placeholder until Wikipedia loads)
        self.qa_chain = None        # ConversationalRetrievalChain once built
        self.initialized = False
        # Conversation memory shared with the chain so follow-ups keep context.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True
        )

    def initialize(self, api_key=None):
        """Install the API key and mark the system ready.

        Returns a (ok: bool, status_message: str) pair.
        """
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key

        # Neither an argument nor a pre-existing environment key: cannot proceed.
        if not api_key and "OPENAI_API_KEY" not in os.environ:
            return False, "OpenAI API key is not set"

        if self.initialized:
            return True, "System already initialized"

        try:
            # Initialize with placeholder content for faster startup
            self.content = "This is placeholder content for Generative AI."
            self.initialized = True
            return True, "System initialized successfully"
        except Exception as e:
            return False, f"Error initializing system: {str(e)}"

    def load_wikipedia(self):
        """Fetch the article, index it, and build the QA chain.

        Returns a human-readable status string for the UI.
        """
        if not self.initialized:
            return "System not initialized. Please set your OpenAI API key first."

        try:
            # Loading Wikipedia page for Generative AI
            loader = WikipediaLoader("Generative artificial intelligence")
            docs = loader.load()
            self.content = docs[0].page_content

            # Split content into chunks
            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            texts = text_splitter.split_text(self.content)

            # Create vector store
            embeddings = OpenAIEmbeddings()
            vectorstore = FAISS.from_texts(texts, embeddings)

            # Set up QA chain
            llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
            self.qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=vectorstore.as_retriever(),
                memory=self.memory
            )

            return "Wikipedia content loaded successfully!"
        except Exception as e:
            return f"Error loading Wikipedia content: {str(e)}"

    def process_query(self, query):
        """Answer *query*; attempts in order: cache, section extraction, QA chain."""
        if not self.initialized:
            return "System not initialized. Please set your OpenAI API key first."

        # Check cache
        cached_answer = self.cache.get(query)
        if cached_answer:
            return f"[Cache] Answer:\n{cached_answer}"

        # Check for section extraction
        if self.content:
            extracted_section = extract_section(query, self.content)
            if extracted_section:
                self.cache.set(query, extracted_section)
                return f"[Function Calling] Section from content:\n{extracted_section}"

        # If QA chain is ready, use it
        if self.qa_chain:
            try:
                result = self.qa_chain({"question": query})
                answer = result.get("answer", "No answer found")
                self.cache.set(query, answer)
                return answer
            except Exception as e:
                return f"Error processing query: {str(e)}"
        else:
            # If Wikipedia content isn't loaded yet
            return "Please load Wikipedia content first by clicking 'Load Wikipedia' in the Settings tab."
122
 
123
  # Initialize system
124
  qa_system = GenAIQASystem()
125
 
126
- # Gradio interface
127
  with gr.Blocks(title="Generative AI Q/A System") as demo:
128
  gr.Markdown("# Generative AI Q/A System")
129
  gr.Markdown("Ask questions about Generative AI using this LangChain-based Q/A system")
@@ -132,40 +187,41 @@ with gr.Blocks(title="Generative AI Q/A System") as demo:
132
  chatbot = gr.Chatbot()
133
  msg = gr.Textbox(label="Your Question")
134
  clear = gr.Button("Clear")
135
-
136
- def respond(message, history):
137
- try:
138
- response = qa_system.process_query(message)
139
- return "", history + [(message, response)]
140
- except Exception as e:
141
- error_message = f"Error processing query: {str(e)}"
142
- return "", history + [(message, error_message)]
143
 
 
 
 
 
144
 
145
- msg.submit(respond, [msg, chatbot], [chatbot])
146
- clear.click(lambda: None, None, chatbot, queue=False)
147
 
148
  with gr.Tab("Settings"):
149
- api_key_input = gr.Textbox(type="password", label="OpenAI API Key")
150
- api_submit = gr.Button("Set API Key")
151
- api_status = gr.Textbox(label="Status")
152
- load_wiki = gr.Button("Load Wikipedia Content")
153
- wiki_status = gr.Textbox(label="Wikipedia Status")
154
-
155
- def set_api_key(api_key):
156
- success, message = qa_system.initialize(api_key)
157
- return message
 
158
 
159
- api_submit.click(set_api_key, [api_key_input], [api_status])
160
- load_wiki.click(qa_system.load_wikipedia, [], wiki_status)
161
 
162
  gr.Markdown("## About")
163
  gr.Markdown("""
164
  This Q/A system uses LangChain and OpenAI to answer questions based on the Wikipedia page about Generative AI.
165
 
166
- Created by Anjali Haryani (Modified for Hugging Face deployment)
 
 
 
 
 
167
  """)
168
 
169
- # Launch the app
170
  if __name__ == "__main__":
171
  demo.launch()
 
1
  import os
2
+ import logging
3
  import gradio as gr
 
 
 
 
 
4
  from langchain.chains import ConversationalRetrievalChain
5
+ from langchain_openai import ChatOpenAI
6
  from langchain.memory import ConversationBufferMemory
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_openai import OpenAIEmbeddings
9
+ from langchain_community.document_loaders import WikipediaLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain.callbacks.base import BaseCallbackHandler
12
 
13
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Memory cache for storing answers
class MemoryCache:
    """In-memory key-value cache mapping query strings to answer strings."""

    def __init__(self):
        # Plain dict, unbounded — acceptable for a demo app's lifetime.
        self.cache = {}

    def get(self, query: str):
        """Return the cached response for *query*, or None on a miss."""
        # Single lookup via EAFP (the original did `in` + `.get`, two lookups).
        try:
            response = self.cache[query]
        except KeyError:
            return None
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info("Cache hit: %s", query)
        return response

    def set(self, query: str, response: str):
        """Store *response* under *query*, overwriting any prior entry."""
        logger.info("Saving to cache: %s", query)
        self.cache[query] = response
31
 
32
# Callback handler for logging
class LoggingCallbackHandler(BaseCallbackHandler):
    """LangChain callback that logs chain / retriever / LLM lifecycle events.

    All calls use lazy %-style logging arguments so no string formatting
    happens when the INFO level is disabled.
    """

    def on_chain_start(self, serialized, inputs, **kwargs):
        logger.info("Chain start. Inputs: %s", inputs)

    def on_chain_end(self, outputs, **kwargs):
        logger.info("Chain end. Outputs: %s", outputs)

    def on_retriever_start(self, *args, **kwargs):
        logger.info("Retrieval start.")

    def on_retriever_end(self, *args, **kwargs):
        logger.info("Retrieval end.")

    def on_llm_start(self, *args, **kwargs):
        logger.info("LLM start.")

    def on_llm_end(self, result, *args, **kwargs):
        # generations is indexed [prompt][candidate]; guard because the exact
        # shape varies across LLM wrappers — logging must never crash the app.
        try:
            final_text = result.generations[0][0].text
            logger.info("LLM end. Text: %s", final_text)
        except Exception as e:
            logger.error("LLM error: %s", e)
55
 
 
56
class GenAIQASystem:
    """Retrieval-augmented Q/A over the Wikipedia article on Generative AI.

    Lifecycle: set_api_key() validates and installs the OpenAI key,
    load_wikipedia() fetches/indexes the article and builds the chain,
    process_query() answers questions (cache, then section extraction,
    then the retrieval chain).
    """

    def __init__(self):
        self.cache = MemoryCache()                    # query -> answer memo
        self.callback_handler = LoggingCallbackHandler()
        self.content = None                           # raw page text once loaded
        self.qa_chain = None                          # ConversationalRetrievalChain once built
        self.memory = None                            # conversation memory for the chain
        self.wiki_loaded = False
        self.api_key_set = False

    def set_api_key(self, api_key):
        """Validate and install the OpenAI API key; return a status string."""
        if not api_key:
            return "Please provide a valid API key."

        try:
            os.environ["OPENAI_API_KEY"] = api_key
            # Round-trip a tiny embedding request to confirm the key works.
            embeddings = OpenAIEmbeddings()
            embeddings.embed_query("Test")
            self.api_key_set = True
            return "API key set successfully!"
        except Exception as e:
            logger.error("API key error: %s", e)
            return f"Error setting API key: {str(e)}"

    def load_wikipedia(self):
        """Fetch the article, build the FAISS index and the QA chain.

        Returns a human-readable status string for the UI. Idempotent:
        repeated calls after success are no-ops.
        """
        if not self.api_key_set:
            return "Please set your OpenAI API key first."

        if self.wiki_loaded:
            return "Wikipedia content already loaded."

        try:
            logger.info("Loading Wikipedia content for Generative artificial intelligence")

            # Load Wikipedia content
            loader = WikipediaLoader(query="Generative artificial intelligence", lang="en")
            documents = loader.load()
            self.content = documents[0].page_content

            # Split content into overlapping chunks for retrieval
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_text(self.content)

            # Create vector store
            embeddings = OpenAIEmbeddings()
            vectorstore = FAISS.from_texts(chunks, embeddings)

            # Conversation memory so follow-up questions keep context
            self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

            # Create QA Chain
            llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
            self.qa_chain = ConversationalRetrievalChain.from_llm(
                llm,
                retriever=vectorstore.as_retriever(),
                memory=self.memory,
                callbacks=[self.callback_handler]
            )

            self.wiki_loaded = True
            return "Wikipedia content loaded successfully!"
        except Exception as e:
            logger.error("Error loading Wikipedia: %s", e)
            return f"Error loading Wikipedia: {str(e)}"

    def extract_section(self, query: str):
        """Return the raw text of a known section if *query* names one, else None."""
        if not self.content:
            return None

        query_lower = query.lower()
        content_lower = self.content.lower()

        # Dictionary of section headers to look for
        sections = {
            "early history": "== early history ==",
            "generative models": "== generative models ==",
            "academic artificial intelligence": "== academic artificial intelligence =="
        }

        # Check if query matches any section
        for key, header in sections.items():
            if key in query_lower:
                # NOTE(review): indices found in the lowercased copy are reused on
                # the original text; this assumes lowercasing preserves length
                # (true for the ASCII article text — confirm if content changes).
                start_index = content_lower.find(header)
                if start_index != -1:
                    logger.info("Found header: %s", header)
                    # Section ends at the next "== ... ==" heading, or end of text.
                    end_index = self.content.find("\n==", start_index + len(header))
                    section_text = self.content[start_index:end_index].strip() if end_index != -1 else self.content[start_index:].strip()
                    return section_text

        return None

    def process_query(self, query):
        """Answer *query*; attempts in order: cache, section extraction, QA chain."""
        if not self.api_key_set:
            return "Please set your OpenAI API key in the Settings tab first."

        if not self.wiki_loaded:
            return "Please load Wikipedia content in the Settings tab first."

        # Check cache first ('is not None' so an empty cached answer still counts
        # as a hit — a plain truthiness test would silently recompute it).
        cached_answer = self.cache.get(query)
        if cached_answer is not None:
            return cached_answer

        # Try to extract a specific section
        extracted_section = self.extract_section(query)
        if extracted_section:
            self.cache.set(query, extracted_section)
            return f"[Section Found] {extracted_section}"

        # Use the QA chain
        try:
            logger.info("Processing query: %s", query)
            result = self.qa_chain.invoke({"question": query})
            answer = result.get("answer", "No answer found")
            self.cache.set(query, answer)
            return answer
        except Exception as e:
            logger.error("Error in QA chain: %s", e)
            return f"Error processing query: {str(e)}"
 
 
 
177
 
178
# Initialize system
# Single module-level QA system; the Gradio handlers below close over it.
qa_system = GenAIQASystem()
180
 
181
+ # Define Gradio interface
182
  with gr.Blocks(title="Generative AI Q/A System") as demo:
183
  gr.Markdown("# Generative AI Q/A System")
184
  gr.Markdown("Ask questions about Generative AI using this LangChain-based Q/A system")
 
187
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Your Question")
        clear = gr.Button("Clear")

        def respond(message, history):
            """Chat submit handler: answer *message* and append the exchange."""
            response = qa_system.process_query(message)
            history.append((message, response))
            # First output clears the textbox, second refreshes the chat display.
            return "", history

        # Enter in the textbox sends the question; outputs map to (msg, chatbot).
        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        # Clear resets the chatbot history to an empty list.
        clear.click(lambda: [], None, chatbot, queue=False)
198
 
199
    with gr.Tab("Settings"):
        with gr.Group():
            # Step 1: the key must be validated before anything else can run.
            gr.Markdown("### Step 1: Set OpenAI API Key")
            api_key_input = gr.Textbox(type="password", label="OpenAI API Key")
            api_submit = gr.Button("Set API Key")
            api_status = gr.Textbox(label="API Status", interactive=False)

        with gr.Group():
            # Step 2: fetch and index the article (requires the key from step 1).
            gr.Markdown("### Step 2: Load Wikipedia Content")
            load_wiki_button = gr.Button("Load Wikipedia Content")
            wiki_status = gr.Textbox(label="Loading Status", interactive=False)

        # Wire the buttons directly to the QA-system methods; both return
        # status strings shown in the corresponding textboxes.
        api_submit.click(qa_system.set_api_key, [api_key_input], [api_status])
        load_wiki_button.click(qa_system.load_wikipedia, [], [wiki_status])

    gr.Markdown("## About")
    gr.Markdown("""
    This Q/A system uses LangChain and OpenAI to answer questions based on the Wikipedia page about Generative AI.

    Features:
    - Caching mechanism to avoid repeating work
    - Function calls to extract specific sections
    - Logging to track processing

    Created by Anjali Haryani
    """)

# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()