HaryaniAnjali committed on
Commit
dc2cc17
·
verified ·
1 Parent(s): 4f007ec

Upload app-py.py

Browse files
Files changed (1) hide show
  1. app-py.py +246 -0
app-py.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from langchain.chat_models import ChatOpenAI
4
+ from langchain.document_loaders import WikipediaLoader
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.vectorstores import FAISS
8
+ from langchain.chains import RetrievalQA
9
+ from langchain.callbacks.base import BaseCallbackHandler
10
+ from langchain.memory import ConversationBufferMemory
11
+ from langchain.chains import ConversationalRetrievalChain
12
+
13
+ # Memory cache to store query answers
14
class MemoryCache:
    """In-process cache mapping a query string to its stored response."""

    def __init__(self):
        # Plain dict: query text -> response text.
        self.cache = {}

    def get(self, query: str):
        """Return the response stored for *query*, or None when absent.

        A hit is announced on stdout before the value is returned.
        """
        try:
            stored = self.cache[query]
        except KeyError:
            return None
        print(f"Cache hit: {query}")
        return stored

    def set(self, query: str, response: str):
        """Store *response* under *query*, replacing any previous entry."""
        print(f"Saving to cache: {query}")
        self.cache[query] = response
26
+
27
+ # Callback handler for logging key steps
28
class LoggingCallbackHandler(BaseCallbackHandler):
    """Callback handler that records chain, retriever and LLM events.

    Each event message is appended to ``self.logs`` and echoed to stdout.
    """

    def __init__(self):
        self.logs = []

    def _record(self, message):
        # Keep the in-memory log and the console output in step.
        self.logs.append(message)
        print(message)

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Log the inputs a chain was started with."""
        self._record(f"Chain start. Inputs: {inputs}")

    def on_chain_end(self, outputs, **kwargs):
        """Log the outputs a chain finished with."""
        self._record(f"Chain end. Outputs: {outputs}")

    def on_retriever_start(self, *args, **kwargs):
        """Log that retrieval has begun."""
        self._record("Retrieval start.")

    def on_retriever_end(self, *args, **kwargs):
        """Log that retrieval has finished."""
        self._record("Retrieval end.")

    def on_llm_start(self, *args, **kwargs):
        """Log that an LLM call has begun."""
        self._record("LLM start.")

    def on_llm_end(self, result, *args, **kwargs):
        """Log the text of the first generation, or the error reading it."""
        try:
            self._record(f"LLM end. Text: {result.generations[0][0].text}")
        except Exception as error:
            # Logging must never break the chain; record the failure instead.
            self._record(f"LLM error: {error}")

    def get_logs(self):
        """Return all recorded messages joined by newlines."""
        return "\n".join(self.logs)

    def clear_logs(self):
        """Discard all recorded messages."""
        self.logs = []
66
+
67
+ # Function to extract a specific section from the content
68
# (query keyword, header text) pairs. Order matters: only the first keyword
# found in the query is tried, mirroring an if/elif chain.
_SECTION_HEADERS = (
    ("early history", "== early history =="),
    ("generative models", "== generative models =="),
    ("academic artificial intelligence", "== academic artificial intelligence =="),
)


def extract_section(query: str, content: str) -> "str | None":
    """Return the section of *content* that *query* asks about, if any.

    The query is matched case-insensitively against a fixed list of keywords.
    For the first keyword present, the corresponding ``== header ==`` is
    searched for (case-insensitively) in *content*; the returned text runs
    from the header up to the next line starting with ``==``, or to the end
    of *content*, with surrounding whitespace stripped.

    Args:
        query: The user's question.
        content: The page text to search.

    Returns:
        The stripped section text, or None when no keyword matches, the
        header is absent, or either argument is empty/None.
    """
    import re  # function-scope import keeps this block self-contained

    if not query or not content:
        return None

    query_lower = query.lower()
    for keyword, header in _SECTION_HEADERS:
        if keyword not in query_lower:
            continue

        # Search the original text directly. Looking the header up in
        # content.lower() and slicing content with that index is unsafe:
        # lower() can change the string length (e.g. "İ"), shifting offsets.
        match = re.search(re.escape(header), content, re.IGNORECASE)
        if match is None:
            print(f"Header not found: {header}")
            return None

        end_index = content.find("\n==", match.end())
        print(f"Found header: {header}")
        if end_index == -1:
            return content[match.start():].strip()
        return content[match.start():end_index].strip()

    return None
103
+
104
+ # Main class for the Q/A system
105
class GenAIQASystem:
    """Q/A pipeline over the Wikipedia page on generative AI.

    A query is answered from the first source that yields a result:
    the in-memory cache, direct section extraction from the page text,
    then the conversational retrieval chain.
    """

    def __init__(self):
        self.cache = MemoryCache()
        self.callback_handler = LoggingCallbackHandler()
        self.content = None  # page text, set by initialize()
        self.qa_chain = None  # retrieval chain, set by initialize()
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True
        )
        self.initialized = False

    def initialize(self, api_key=None):
        """Load the page, build the vector store and create the chain.

        Args:
            api_key: Optional OpenAI API key. When non-blank it is stripped
                and written to the ``OPENAI_API_KEY`` environment variable.

        Returns:
            A ``(success, message)`` tuple.
        """
        # Ignore blank / whitespace-only keys instead of storing them.
        key = (api_key or "").strip()
        if key:
            os.environ["OPENAI_API_KEY"] = key

        if not os.environ.get("OPENAI_API_KEY"):
            return False, "OpenAI API key is not set"

        if self.initialized:
            return True, "System already initialized"

        try:
            # Loading Wikipedia page for Generative AI
            print("Loading Wikipedia page content for Generative artificial intelligence")
            # Only the first document is used, so do not fetch more than one.
            # NOTE(review): the loader may truncate page text
            # (doc_content_chars_max); confirm the section headers that
            # extract_section looks for survive truncation.
            loader = WikipediaLoader(
                "Generative artificial intelligence",
                load_max_docs=1,
            )
            docs = loader.load()
            if not docs:
                return False, "Error initializing system: no Wikipedia content was loaded"
            self.content = docs[0].page_content
            print("Page loaded\n")

            # Split the content into small chunks
            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            texts = text_splitter.split_text(self.content)

            # Create a vector store using embeddings from the text chunks
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
            vectorstore = FAISS.from_texts(texts, embeddings)

            # Set up the LLM with OpenAI model
            llm = ChatOpenAI(
                model="gpt-3.5-turbo",
                temperature=0,
                callbacks=[self.callback_handler]
            )

            # Use conversational retrieval chain for chat
            self.qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=vectorstore.as_retriever(),
                memory=self.memory,
                callbacks=[self.callback_handler]
            )

            self.initialized = True
            return True, "System initialized successfully"
        except Exception as e:
            # Boundary for the UI: report any setup failure as a status message.
            return False, f"Error initializing system: {str(e)}"

    def process_query(self, query):
        """Answer *query* and return the response text.

        Cache hits are prefixed with ``[Cache]`` and extracted sections with
        ``[Function Calling]``; chain answers are returned as-is. Extracted
        sections and chain answers are stored in the cache.
        """
        if not self.initialized:
            return "System not initialized. Please set your OpenAI API key first."

        # Check if the answer is in the cache. Compare against None so that a
        # cached empty answer still counts as a hit.
        cached_answer = self.cache.get(query)
        if cached_answer is not None:
            return f"[Cache] Answer:\n{cached_answer}"

        # Try to extract a specific section from the content
        extracted_section = extract_section(query, self.content)
        if extracted_section:
            self.cache.set(query, extracted_section)
            return f"[Function Calling] Section from content:\n{extracted_section}"

        # Use the retrieval Q/A chain to get the answer
        self.callback_handler.clear_logs()
        print("\n[Retrieval] Processing query...")
        result = self.qa_chain({"question": query})
        answer = result.get("answer", "No answer found")
        self.cache.set(query, answer)

        return answer

    def get_logs(self):
        """Return the callback handler's recorded log text."""
        return self.callback_handler.get_logs()
189
+
190
+ # Initialize the system
191
qa_system = GenAIQASystem()


# Define the Gradio interface
def set_api_key(api_key):
    """Initialize the shared Q/A system with *api_key*; return the status text."""
    _succeeded, status_message = qa_system.initialize(api_key)
    return status_message
197
+
198
def respond(message, history):
    """Return the Q/A system's answer to *message* as a string.

    *history* is accepted for the chat callback signature but is not used.
    When the system has not been initialized, a prompt to set the API key
    is returned instead.
    """
    if qa_system.initialized:
        return qa_system.process_query(message)
    return "Please set your OpenAI API key first in the Settings tab."
204
+
205
def view_logs():
    """Return the Q/A system's current log text."""
    log_text = qa_system.get_logs()
    return log_text
207
+
208
+ # Gradio interface
209
def _chat_turn(message, history):
    """Run one chat turn; return the new textbox value and chat history.

    ``respond`` returns a plain string, but a ``gr.Chatbot`` output expects
    the whole conversation, so the new turn is appended to *history* here.
    """
    reply = respond(message, history)
    # NOTE(review): this uses the [user, bot] pair format. If the installed
    # Gradio version uses the "messages" format, append role/content dicts
    # instead; confirm against the pinned version.
    updated_history = list(history or []) + [[message, reply]]
    # The empty string clears the question box after submission.
    return "", updated_history


with gr.Blocks(title="Generative AI Q/A System") as demo:
    gr.Markdown("# Generative AI Q/A System")
    gr.Markdown("Ask questions about Generative AI using this LangChain-based Q/A system")

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Your Question")
        clear = gr.Button("Clear")

        msg.submit(_chat_turn, [msg, chatbot], [msg, chatbot])
        # Returning None resets the chatbot display.
        clear.click(lambda: None, None, chatbot, queue=False)

    with gr.Tab("System Logs"):
        logs_output = gr.Textbox(label="System Logs", lines=20)
        view_logs_button = gr.Button("View Logs")
        view_logs_button.click(view_logs, [], logs_output)

    with gr.Tab("Settings"):
        api_key_input = gr.Textbox(type="password", label="OpenAI API Key")
        api_submit = gr.Button("Set API Key")
        api_status = gr.Textbox(label="Status")

        api_submit.click(set_api_key, [api_key_input], [api_status])

        gr.Markdown("## About")
        gr.Markdown("""
        This Q/A system uses LangChain and OpenAI to answer questions based on the Wikipedia page about Generative AI.

        Features:
        - Caching mechanism to avoid repeating work
        - Function calls to extract specific details
        - Callback logging to track processing

        Created by Anjali Haryani (Modified for Hugging Face deployment)
        """)

if __name__ == "__main__":
    demo.launch()