# Copy-AI / app.py
# Source: Hugging Face Space "Copy-AI" by Wajahat698 (commit 30cfc45, 6.93 kB)
import logging
import os
import requests
from dotenv import load_dotenv
import openai
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents import tool, AgentExecutor
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
import serpapi
import streamlit as st
# Initialize logging; the module-level logger is reused by every function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from .env file into the process environment.
load_dotenv()

# Define and validate API keys.
# NOTE(review): the key is read from SERPER_API_KEY but is later passed to the
# `serpapi` client -- confirm which service (Serper vs SerpAPI) is intended.
openai_api_key = os.getenv("OPENAI_API_KEY")
serper_api_key = os.getenv("SERPER_API_KEY")
if not openai_api_key or not serper_api_key:
    # Fail fast: the app is unusable without both keys.
    logger.error("API keys are not set properly.")
    st.error("API keys for OpenAI and SERPER must be set in the .env file.")
    st.stop()

# Initialize OpenAI client (sets the module-level key on the `openai` package).
try:
    openai.api_key = openai_api_key
    logger.info("OpenAI client initialized successfully.")
except Exception as e:
    # A plain attribute assignment is unlikely to raise; kept for parity with
    # the file's defensive try/except style.
    logger.error(f"Error initializing OpenAI client: {e}")
    st.error(f"Error initializing OpenAI client: {e}")
    st.stop()
# Load knowledge base
def load_knowledge_base():
    """Load the trust-book markdown file and split it into ~1000-char chunks.

    Returns:
        list: The split Document chunks.  On any failure the error is logged,
        surfaced in the Streamlit UI, and the script run is stopped.
    """
    try:
        raw_documents = TextLoader(
            "./data_source/time_to_rethink_trust_book.md"
        ).load()
        splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        return splitter.split_documents(raw_documents)
    except Exception as e:
        logger.error(f"Error loading knowledge base: {e}")
        st.error(f"Error loading knowledge base: {e}")
        st.stop()
# Build the retrieval corpus once at startup.
knowledge_base = load_knowledge_base()

# Initialize embeddings and FAISS index over the corpus (in-memory index,
# rebuilt on every script run).
try:
    embeddings = OpenAIEmbeddings()
    db = FAISS.from_documents(knowledge_base, embeddings)
except Exception as e:
    logger.error(f"Error initializing FAISS index: {e}")
    st.error(f"Error initializing FAISS index: {e}")
    st.stop()
# Define search function for knowledge base
def search_knowledge_base(query):
    """Return the FAISS documents most similar to `query`.

    Args:
        query (str): The text to search for.

    Returns:
        list: Matching Document objects, or an empty list when the search
        fails.  Returning an empty list -- rather than the previous list of
        error strings -- keeps the return type consistent, so callers such
        as `rag_response` can always read `doc.page_content` safely.
    """
    try:
        return db.similarity_search(query)
    except Exception as e:
        logger.error(f"Error searching knowledge base: {e}")
        return []
# SERPER API Google Search function
def google_search(query):
    """Run a Google search through the serpapi client and return snippets.

    Args:
        query (str): The search query.

    Returns:
        list: Snippet strings from the organic results.  Results without a
        "snippet" key are skipped (previously one such result raised
        KeyError and discarded the entire search), or a single-item error
        message list on failure.
    """
    try:
        search_client = serpapi.Client(api_key=serper_api_key)
        results = search_client.search({"engine": "google", "q": query})
        return [
            result["snippet"]
            for result in results.get("organic_results", [])
            if "snippet" in result
        ]
    except requests.exceptions.HTTPError as http_err:
        logger.error(f"HTTP error occurred: {http_err}")
        return ["HTTP error occurred during Google search"]
    except Exception as e:
        logger.error(f"General Error: {e}")
        return ["Error occurred during Google search"]
# RAG response function
def rag_response(query):
    """Answer `query` with GPT-4o, grounded in retrieved knowledge-base text.

    Args:
        query (str): The user's question.

    Returns:
        str: The model's answer, or an error message string on failure.
    """
    try:
        retrieved_docs = search_knowledge_base(query)
        # The retriever may yield Document objects or plain strings (its
        # error path historically returned strings); accept both so context
        # assembly never raises AttributeError and masks the real error.
        context = "\n".join(
            getattr(doc, "page_content", str(doc)) for doc in retrieved_docs
        )
        prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
        llm = ChatOpenAI(model="gpt-4o", temperature=0.5, api_key=openai_api_key)
        response = llm.invoke(prompt)
        return response.content
    except Exception as e:
        logger.error(f"Error generating RAG response: {e}")
        return "Error occurred during RAG response generation"
# Define tools using LangChain's `tool` decorator
@tool
def knowledge_base_tool(query: str):
    """
    Tool function to query the knowledge base and retrieve a response.
    Args:
        query (str): The query to search the knowledge base.
    Returns:
        str: The response retrieved from the knowledge base.
    """
    # NOTE: the docstring above doubles as the tool description the agent's
    # LLM sees, so keep it accurate if the tool's behavior changes.
    return rag_response(query)
@tool
def google_search_tool(query: str):
    """
    Tool function to perform a Google search using the SERPER API.
    Args:
        query (str): The query to search on Google.
    Returns:
        list: List of snippets extracted from search results.
    """
    # NOTE: the docstring above doubles as the tool description the agent's
    # LLM sees, so keep it accurate if the tool's behavior changes.
    return google_search(query)
# Tools the agent may call during a turn.
tools = [knowledge_base_tool, google_search_tool]

# Create the prompt template.
# System persona for the copywriting agent; the two placeholders are filled
# per call with prior turns and with intermediate tool-call messages.
prompt_message = """
Act as an expert copywriter who specializes in creating compelling marketing copy using AI technologies.
Engage in a friendly and informative conversation based on the knowledge base.
Only proceed to create sales materials when the user explicitly requests it.
Work together with the user to update the outcome of the sales material.
"""
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", prompt_message),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)
# Create Langchain Agent with specific model and temperature
try:
    llm = ChatOpenAI(model="gpt-4o", temperature=0.5)  # reads OPENAI_API_KEY from env
    llm_with_tools = llm.bind_tools(tools)  # lets the model emit tool calls
except Exception as e:
    logger.error(f"Error creating Langchain Agent: {e}")
    st.error(f"Error creating Langchain Agent: {e}")
    st.stop()
# Define the agent pipeline to handle the conversation flow
try:
    agent = (
        {
            # Map the executor's input dict onto the prompt's variables.
            "input": lambda x: x["input"],
            # Convert intermediate (tool call, observation) pairs into the
            # OpenAI tool-message format the scratchpad placeholder expects.
            "agent_scratchpad": lambda x: format_to_openai_tool_messages(x["intermediate_steps"]),
            "chat_history": lambda x: x["chat_history"],
        }
        | prompt_template
        | llm_with_tools
        | OpenAIToolsAgentOutputParser()
    )
    # Instantiate an AgentExecutor to execute the defined agent pipeline
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
except Exception as e:
    logger.error(f"Error defining agent pipeline: {e}")
    st.error(f"Error defining agent pipeline: {e}")
    st.stop()
# Initialize chat history.
# Streamlit re-executes this script on every interaction, so a plain module
# global (`chat_history = []`) was reset each rerun and the agent never saw
# prior turns.  Persist the list in st.session_state instead.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
chat_history = st.session_state.chat_history

def chatbot_response(message, history):
    """Run one conversational turn through the agent executor.

    Args:
        message (str): The user's new message.
        history (list): Mutable list of HumanMessage/AIMessage turns; it is
            given to the agent as context and extended in place with this
            exchange.  (Previously this parameter was ignored in favor of
            the module global.)

    Returns:
        str: The agent's reply, or an error message string on failure.
    """
    try:
        output = agent_executor.invoke({"input": message, "chat_history": history})
        history.extend([HumanMessage(content=message), AIMessage(content=output["output"])])
        return output["output"]
    except Exception as e:
        logger.error(f"Error generating chatbot response: {e}")
        return "Error occurred during response generation"
# Streamlit app
# Input field plus a submit button; on submit, run the agent and display
# its reply, warning when the message box is empty.
user_input = st.text_input("You:", "")
if st.button("Submit"):
    if not user_input:
        st.warning("Please enter a message.")
    else:
        st.write("AI:", chatbot_response(user_input, chat_history))