|
|
import os |
|
|
import gradio as gr |
|
|
from langchain.chat_models import ChatOpenAI |
|
|
from langchain.document_loaders import WikipediaLoader |
|
|
from langchain.text_splitter import CharacterTextSplitter |
|
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
|
from langchain.vectorstores import FAISS |
|
|
from langchain.chains import ConversationalRetrievalChain |
|
|
from langchain.callbacks.base import BaseCallbackHandler |
|
|
from langchain.memory import ConversationBufferMemory |
|
|
import traceback |
|
|
|
|
|
|
|
|
class MemoryCache:
    """Minimal in-memory cache mapping query strings to response strings."""

    def __init__(self):
        # query text -> previously computed response
        self.cache = {}

    def get(self, query: str):
        """Return the cached response for *query*, or None when absent."""
        if query not in self.cache:
            return None
        print(f"Cache hit: {query}")
        return self.cache[query]

    def set(self, query: str, response: str):
        """Store *response* under *query*, overwriting any earlier entry."""
        print(f"Saving to cache: {query}")
        self.cache[query] = response
|
|
|
|
|
|
|
|
class LoggingCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler that mirrors pipeline events to stdout
    and accumulates them in memory for later display in the UI."""

    def __init__(self):
        # Accumulated log lines, oldest first.
        self.logs = []

    def _record(self, message):
        # Append to the in-memory log and echo to stdout.
        self.logs.append(message)
        print(message)

    def on_chain_start(self, serialized, inputs, **kwargs):
        self._record(f"Chain start. Inputs: {inputs}")

    def on_chain_end(self, outputs, **kwargs):
        self._record(f"Chain end. Outputs: {outputs}")

    def on_retriever_start(self, *args, **kwargs):
        self._record("Retrieval start.")

    def on_retriever_end(self, *args, **kwargs):
        self._record("Retrieval end.")

    def on_llm_start(self, *args, **kwargs):
        self._record("LLM start.")

    def on_llm_end(self, result, *args, **kwargs):
        # `result` is expected to be an LLMResult; guard against shape
        # differences across langchain versions rather than crashing.
        try:
            final_text = result.generations[0][0].text
            self._record(f"LLM end. Text: {final_text}")
        except Exception as e:
            self._record(f"LLM error: {e}")

    def get_logs(self):
        """Return every recorded line joined with newlines."""
        return "\n".join(self.logs)

    def clear_logs(self):
        """Discard all recorded lines."""
        self.logs = []
|
|
|
|
|
|
|
|
class GenAIQASystem:
    """Conversational Q/A over the Wikipedia "Generative artificial intelligence" page.

    Wires together a FAISS vector store (persisted under ./faiss_index),
    an OpenAI chat model and a ConversationalRetrievalChain, plus an
    in-memory response cache and a logging callback handler.
    """

    def __init__(self):
        self.cache = MemoryCache()                    # query -> answer cache
        self.callback_handler = LoggingCallbackHandler()
        self.content = None                           # raw Wikipedia page text (None until loaded)
        self.qa_chain = None                          # built lazily by initialize()
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        self.initialized = False

    def initialize(self, api_key=None):
        """Build the retrieval chain (idempotent).

        Args:
            api_key: Optional OpenAI key; when given it is exported to the
                environment before anything else happens.

        Returns:
            A (success, message) tuple. Never raises: all failures are
            reported through the message so the UI can display them.
        """
        try:
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key

            if "OPENAI_API_KEY" not in os.environ:
                return False, "OpenAI API key is not set"

            if self.initialized:
                return True, "System already initialized"

            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

            if os.path.exists("faiss_index"):
                # Reuse the persisted index. This also lets us skip the slow
                # Wikipedia download and re-embedding entirely (the original
                # code downloaded the page even on this path).
                print("Loading FAISS index from disk...")
                # NOTE(review): recent langchain releases require
                # allow_dangerous_deserialization=True here — confirm against
                # the pinned langchain version.
                vectorstore = FAISS.load_local("faiss_index", embeddings)
            else:
                print("Loading Wikipedia page content for Generative artificial intelligence")
                loader = WikipediaLoader(query="Generative artificial intelligence")
                docs = loader.load()
                if not docs:
                    return False, "Wikipedia content not loaded. Check query or connection."
                self.content = docs[0].page_content
                print("Page loaded")

                text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
                texts = text_splitter.split_text(self.content)

                print("Creating new FAISS index...")
                vectorstore = FAISS.from_texts(texts, embeddings)
                vectorstore.save_local("faiss_index")

            llm = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0,
                callbacks=[self.callback_handler]
            )

            self.qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=vectorstore.as_retriever(),
                memory=self.memory,
                callbacks=[self.callback_handler]
            )

            self.initialized = True
            return True, "System initialized successfully"

        except Exception as e:
            # Top-level boundary: log the full traceback, report a short
            # message to the caller/UI.
            print(traceback.format_exc())
            return False, f"Error initializing system: {str(e)}"

    def process_query(self, query):
        """Answer *query*, serving from the cache when possible.

        Returns the answer text, or a human-readable error string. Never
        raises, so the Gradio callback cannot crash the UI.
        """
        if not self.initialized:
            return "System not initialized. Please set your OpenAI API key first."

        cached_answer = self.cache.get(query)
        # Compare against None rather than truthiness so a legitimately
        # empty cached answer still counts as a hit.
        if cached_answer is not None:
            return f"[Cache] Answer:\n{cached_answer}"

        self.callback_handler.clear_logs()
        print("\n[Retrieval] Processing query...")
        try:
            result = self.qa_chain({"question": query})
        except Exception as e:
            # Surface chain failures (network, quota, auth, ...) as a message
            # instead of an unhandled exception in the Gradio handler.
            print(traceback.format_exc())
            return f"Error processing query: {str(e)}"

        answer = result.get("answer", "No answer found")
        self.cache.set(query, answer)
        return answer

    def get_logs(self):
        """Return the callback handler's accumulated logs as one string."""
        return self.callback_handler.get_logs()
|
|
|
|
|
|
|
|
# Module-level singleton shared by all Gradio callbacks below.
qa_system = GenAIQASystem()
|
|
|
|
|
|
|
|
def set_api_key(api_key):
    """Initialize the QA system with the supplied OpenAI key.

    Returns the human-readable status message for display in the UI.
    """
    _, status = qa_system.initialize(api_key)
    return status
|
|
|
|
|
def respond(message, history):
    """Chat callback for gr.ChatInterface.

    gr.ChatInterface manages the conversation history itself and expects
    this function to return only the assistant's reply text. The original
    version appended to and returned the history list, which ChatInterface
    would render (stringified) as the bot message. `history` is accepted
    but unused — the signature is dictated by ChatInterface.
    """
    return qa_system.process_query(message)
|
|
|
|
|
def view_logs():
    """Return the QA system's accumulated callback logs for the UI."""
    return qa_system.get_logs()
|
|
|
|
|
|
|
|
# Top-level Gradio UI: a chat tab backed by `respond`, a log viewer, and a
# settings tab for entering the OpenAI API key.
with gr.Blocks(title="Generative AI Q/A System") as demo:
    gr.Markdown("# Generative AI Q/A System")
    gr.Markdown("Ask questions about Generative AI using this LangChain-based Q/A system.")

    with gr.Tab("Chat"):
        # ChatInterface supplies the chatbot widget, input box and history
        # plumbing; it calls respond(message, history) on each submit.
        chat_interface = gr.ChatInterface(fn=respond)

    with gr.Tab("System Logs"):
        logs_output = gr.Textbox(label="System Logs", lines=20)
        view_logs_button = gr.Button("View Logs")
        # No inputs; clicking refreshes the textbox from the handler's logs.
        view_logs_button.click(view_logs, [], logs_output)

    with gr.Tab("Settings"):
        api_key_input = gr.Textbox(type="password", label="OpenAI API Key")
        api_submit = gr.Button("Set API Key")
        api_status = gr.Textbox(label="Status")
        # Initializes the whole QA pipeline; status message shown below.
        api_submit.click(set_api_key, [api_key_input], [api_status])

    gr.Markdown("## About")
    gr.Markdown("""
    This Q/A system uses LangChain and OpenAI to answer questions based on the Wikipedia page about Generative AI.

    Features:
    - Caching mechanism to avoid repeating work
    - Callback logging to track processing
    - Persistent vector database (FAISS)

    Created by Anjali Haryani (Modified for Hugging Face deployment)
    """)
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app only when run as a script (not when imported,
    # e.g. by a Hugging Face Spaces runner).
    demo.launch()