# HaryaniAnjali's picture
# Update app.py
# 91ef8ce verified
import os
import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.memory import ConversationBufferMemory
import traceback
# --- Memory Cache ---
class MemoryCache:
    """Simple in-memory key/value cache mapping query strings to responses."""

    def __init__(self):
        # Backing store: query string -> cached response.
        self.cache = {}

    def get(self, query: str):
        """Return the cached response for *query*, or None on a miss."""
        if query not in self.cache:
            return None
        print(f"Cache hit: {query}")
        return self.cache[query]

    def set(self, query: str, response: str):
        """Store *response* under *query*, overwriting any prior entry."""
        print(f"Saving to cache: {query}")
        self.cache[query] = response
# --- Callback Logger ---
class LoggingCallbackHandler(BaseCallbackHandler):
    """Records LangChain lifecycle events as human-readable log lines.

    Every event is both appended to an internal buffer (readable via
    get_logs) and echoed to stdout.
    """

    def __init__(self):
        self.logs = []

    def _log(self, message):
        # Single sink: buffer the line and mirror it to stdout.
        self.logs.append(message)
        print(message)

    def on_chain_start(self, serialized, inputs, **kwargs):
        self._log(f"Chain start. Inputs: {inputs}")

    def on_chain_end(self, outputs, **kwargs):
        self._log(f"Chain end. Outputs: {outputs}")

    def on_retriever_start(self, *args, **kwargs):
        self._log("Retrieval start.")

    def on_retriever_end(self, *args, **kwargs):
        self._log("Retrieval end.")

    def on_llm_start(self, *args, **kwargs):
        self._log("LLM start.")

    def on_llm_end(self, result, *args, **kwargs):
        # result is an LLMResult; guard against unexpected shapes.
        try:
            self._log(f"LLM end. Text: {result.generations[0][0].text}")
        except Exception as e:
            self._log(f"LLM error: {e}")

    def get_logs(self):
        """Return all buffered log lines joined with newlines."""
        return "\n".join(self.logs)

    def clear_logs(self):
        """Drop all buffered log lines."""
        self.logs = []
# --- Q&A System ---
class GenAIQASystem:
    """Conversational Q&A over the Wikipedia "Generative artificial intelligence" page.

    Wires together a Wikipedia loader, a FAISS vector store (persisted to
    ./faiss_index), HuggingFace sentence embeddings, and an OpenAI chat model
    into a ConversationalRetrievalChain with conversation-buffer memory and a
    per-query response cache.
    """

    def __init__(self):
        self.cache = MemoryCache()
        self.callback_handler = LoggingCallbackHandler()
        self.content = None       # raw Wikipedia page text; set in initialize()
        self.qa_chain = None      # ConversationalRetrievalChain; set in initialize()
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.initialized = False

    def initialize(self, api_key=None):
        """Build the retrieval chain (idempotent).

        Args:
            api_key: optional OpenAI key; exported to OPENAI_API_KEY when given.

        Returns:
            Tuple (success: bool, message: str). Never raises — failures are
            reported in the message and logged via traceback.
        """
        try:
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key
            if "OPENAI_API_KEY" not in os.environ:
                return False, "OpenAI API key is not set"
            if self.initialized:
                return True, "System already initialized"

            print("Loading Wikipedia page content for Generative artificial intelligence")
            loader = WikipediaLoader(query="Generative artificial intelligence")
            docs = loader.load()
            if not docs:
                return False, "Wikipedia content not loaded. Check query or connection."
            self.content = docs[0].page_content
            print("Page loaded")

            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            texts = text_splitter.split_text(self.content)
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

            if os.path.exists("faiss_index"):
                print("Loading FAISS index from disk...")
                # NOTE(review): recent LangChain versions require
                # allow_dangerous_deserialization=True for load_local —
                # confirm against the pinned langchain version.
                vectorstore = FAISS.load_local("faiss_index", embeddings)
            else:
                print("Creating new FAISS index...")
                vectorstore = FAISS.from_texts(texts, embeddings)
                vectorstore.save_local("faiss_index")

            llm = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0,
                callbacks=[self.callback_handler],
            )
            self.qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=vectorstore.as_retriever(),
                memory=self.memory,
                callbacks=[self.callback_handler],
            )
            self.initialized = True
            return True, "System initialized successfully"
        except Exception as e:
            print(traceback.format_exc())
            return False, f"Error initializing system: {str(e)}"

    def process_query(self, query):
        """Answer *query*, serving from the cache when possible.

        Returns the answer string; on a cache hit the answer is prefixed
        with "[Cache] Answer:".
        """
        if not self.initialized:
            return "System not initialized. Please set your OpenAI API key first."
        # Compare against None rather than truthiness so a legitimately
        # empty cached answer still counts as a hit (the original truthiness
        # check would silently re-run the chain for falsy answers).
        cached_answer = self.cache.get(query)
        if cached_answer is not None:
            return f"[Cache] Answer:\n{cached_answer}"
        self.callback_handler.clear_logs()
        print("\n[Retrieval] Processing query...")
        result = self.qa_chain({"question": query})
        answer = result.get("answer", "No answer found")
        self.cache.set(query, answer)
        return answer

    def get_logs(self):
        """Return the callback handler's accumulated logs as one string."""
        return self.callback_handler.get_logs()
# --- Create System Instance ---
# Module-level singleton shared by every Gradio callback below.
qa_system = GenAIQASystem()
# --- Gradio Interface ---
def set_api_key(api_key):
    """Initialize the Q&A system with the supplied OpenAI key; return status text."""
    # The success flag is unused here — the status message alone is shown in the UI.
    _, status = qa_system.initialize(api_key)
    return status
def respond(message, history):
    """Chat callback for gr.ChatInterface.

    gr.ChatInterface manages the conversation history itself and expects the
    callback to return the assistant's reply as a string. The original
    returned the mutated history list, which Gradio would render as the
    reply text instead of the answer.

    Args:
        message: the user's latest message.
        history: prior (user, bot) turns, supplied by Gradio; unused here.

    Returns:
        The answer string from the Q&A system.
    """
    return qa_system.process_query(message)
def view_logs():
    """Expose the callback handler's accumulated logs to the Logs tab."""
    logs = qa_system.get_logs()
    return logs
# --- Gradio UI ---
with gr.Blocks(title="Generative AI Q/A System") as demo:
    gr.Markdown("# Generative AI Q/A System")
    gr.Markdown("Ask questions about Generative AI using this LangChain-based Q/A system.")

    # Chat tab: delegates each message to respond() via the built-in chat widget.
    with gr.Tab("Chat"):
        chat_interface = gr.ChatInterface(fn=respond)

    # Logs tab: on-demand dump of the callback handler's log buffer.
    with gr.Tab("System Logs"):
        logs_output = gr.Textbox(label="System Logs", lines=20)
        view_logs_button = gr.Button("View Logs")
        view_logs_button.click(view_logs, [], logs_output)

    # Settings tab: collect the OpenAI key and initialize the system.
    with gr.Tab("Settings"):
        api_key_input = gr.Textbox(type="password", label="OpenAI API Key")
        api_submit = gr.Button("Set API Key")
        api_status = gr.Textbox(label="Status")
        api_submit.click(set_api_key, [api_key_input], [api_status])

    gr.Markdown("## About")
    gr.Markdown("""
This Q/A system uses LangChain and OpenAI to answer questions based on the Wikipedia page about Generative AI.
Features:
- Caching mechanism to avoid repeating work
- Callback logging to track processing
- Persistent vector database (FAISS)
Created by Anjali Haryani (Modified for Hugging Face deployment)
""")

if __name__ == "__main__":
    demo.launch()