|
|
import gradio as gr |
|
|
import os |
|
|
from langchain.vectorstores import FAISS |
|
|
from langchain.embeddings.openai import OpenAIEmbeddings |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain.document_loaders import WebBaseLoader |
|
|
from langchain.tools import Tool |
|
|
from langchain.chains import RetrievalQA |
|
|
from langchain.chat_models import ChatOpenAI |
|
|
from langchain.agents import initialize_agent, AgentType |
|
|
from langchain.memory import ConversationBufferMemory |
|
|
|
|
|
|
|
|
|
|
|
# Fail fast with a clear message if the key is missing. The original
# `os.environ[...] = os.getenv(...)` was a no-op when the key existed and
# raised a confusing TypeError (environ values must be str, not None) when
# it did not.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if not _openai_api_key:
    raise EnvironmentError("OPENAI_API_KEY environment variable is not set.")
os.environ["OPENAI_API_KEY"] = _openai_api_key
|
|
|
|
|
|
|
|
# Source page that seeds the chatbot's knowledge base.
url = "https://www.halodesigns.in/"


loader = WebBaseLoader(url)


# Network I/O at import time: the app will not start if this site is unreachable.
documents = loader.load()
|
|
|
|
|
|
|
|
# Split the scraped page into ~500-character chunks with 50 characters of
# overlap so sentences that straddle a boundary appear in both chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)


docs = text_splitter.split_documents(documents)
|
|
|
|
|
|
|
|
# Embed each chunk with OpenAI and index the vectors in an in-memory FAISS store.
embeddings = OpenAIEmbeddings()
vector_store = FAISS.from_documents(docs, embeddings)

# BUG FIX: `as_retriever` does not take a `search_k` keyword — the number of
# results is passed via `search_kwargs={"k": ...}`. The original `search_k=3`
# was never applied, so the retriever silently used its default k instead of
# the intended top-3 chunks.
retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})
|
|
|
|
|
|
|
|
|
|
|
# Single chat model shared by the QA chain, the general-query tool, and the agent.
llm = ChatOpenAI(model="gpt-4o-mini")
|
|
|
|
|
|
|
|
def document_retrieval(query: str):
    """Run *query* through the RetrievalQA chain and return only the answer text.

    NOTE: `retrieval_qa_chain` is defined later in the module; that is fine
    because this function is only invoked after the module has fully loaded.
    """
    response = retrieval_qa_chain({"query": query})
    return response["result"]
|
|
|
|
|
|
|
|
# QA chain used by the agent's retrieval tool and by the startup
# summary/question probes below. return_source_documents=True keeps the
# retrieved chunks in the response dict, although only "result" is read here.
retrieval_qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    return_source_documents=True
)
|
|
|
|
|
|
|
|
def get_document_summary():
    """Ask the QA chain for an exhaustive summary of the indexed site."""
    prompt = "Summarize the document in detail. Do Not Miss Any points."
    return retrieval_qa_chain({"query": prompt})["result"]
|
|
|
|
|
|
|
|
# Computed once at import time and embedded into the retrieval tool's
# description below so the agent knows what the indexed site covers.
# Costs one LLM round-trip at startup.
document_summary = get_document_summary()
|
|
|
|
|
|
|
|
def get_document_questions():
    """Ask the QA chain to enumerate the questions the indexed site can answer."""
    prompt = "List all the possible questions based on the given context. Do Not Miss Any questions."
    return retrieval_qa_chain({"query": prompt})["result"]
|
|
|
|
|
|
|
|
# Also computed once at import time and embedded into the retrieval tool's
# description below. Costs one LLM round-trip at startup.
document_questions = get_document_questions()
|
|
|
|
|
|
|
|
|
|
|
# Fallback tool: routes anything outside the indexed site straight to the LLM.
llm_tool = Tool(
    name="General Query LLM",
    description=(
        "Uses LLM to answer general knowledge questions (e.g., greetings, "
        "sports, world events). Does NOT handle RAG-related queries."
    ),
    # The bound method is already a str -> str callable; no lambda needed.
    func=llm.predict,
)
|
|
|
|
|
# RAG tool: its description advertises the indexed site's summary and the
# questions it can answer, so the agent knows when to route a query here.
_retrieval_tool_description = (
    "This tool retrieves information that contains following information: \n"
    f"{document_summary}\n"
    "Also the following questions: \n"
    f"{document_questions}"
)

document_retrieval_tool = Tool(
    name="Document Retrieval",
    description=_retrieval_tool_description,
    func=document_retrieval,
)
|
|
|
|
|
|
|
|
# Buffer that accumulates the full chat transcript; the agent reads it back
# through the "chat_history" prompt variable.
memory = ConversationBufferMemory(memory_key="chat_history")
|
|
|
|
|
|
|
|
# BUG FIX: ZERO_SHOT_REACT_DESCRIPTION's prompt has no {chat_history} slot, so
# pairing it with ConversationBufferMemory(memory_key="chat_history") either
# raises a prompt-variable error or silently drops the conversation history.
# CONVERSATIONAL_REACT_DESCRIPTION is the agent type built for memory-backed
# chat. handle_parsing_errors keeps one malformed LLM reply from aborting a turn.
agent = initialize_agent(
    tools=[llm_tool, document_retrieval_tool],
    llm=llm,
    agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
    verbose=True,
    memory=memory,
    handle_parsing_errors=True,
)
|
|
|
|
|
|
|
|
def chatbot_response(user_input, history):
    """Gradio callback: run one chat turn through the agent.

    ``history`` is supplied by gr.ChatInterface but intentionally unused —
    the agent tracks conversation state in its own memory object. Any
    exception is rendered into the chat as an "Error: ..." message so the
    UI never crashes mid-conversation.
    """
    try:
        return agent.run(user_input)
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
|
|
|
|
# Launch the Gradio chat UI; debug=True surfaces server-side tracebacks in the console.
gr.ChatInterface(fn=chatbot_response, title="Halo Designs Chatbot", theme="soft").launch(debug=True)