File size: 6,859 Bytes
72c6641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8867a90
72c6641
8867a90
72c6641
 
 
 
 
 
 
 
 
8867a90
72c6641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8867a90
72c6641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# STREAMLIT VERSION 2.1 - PDF WORKING

import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.document_loaders import WebBaseLoader, YoutubeLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import ChatOpenAI
from PIL import Image
from PyPDF2 import PdfReader
# from dotenv import load_dotenv

# Load secrets from .env file
# load_dotenv()


def extract_data_from_url(url):
    '''
    Load the web page at the given address with LangChain's WebBaseLoader.

    args: url (str): address of the page to load.
    returns: [Document] — whatever WebBaseLoader.load() produces for the page.
    '''
    return WebBaseLoader(url).load()


def extract_transcript_from_youtube_url(youtube_url):
    '''
    Fetch a YouTube video's transcript through LangChain's YoutubeLoader.

    Video metadata is not requested (add_video_info=False).

    args: youtube_url (str): link to the YouTube video.
    returns: [Document] — the loader's output for that video.
    '''
    loader = YoutubeLoader.from_youtube_url(youtube_url, add_video_info=False)
    return loader.load()


def create_vectorstore_from_pdf(uploaded_pdf):
    '''
    Extract the text of a PDF, split it into chunks, embed the chunks and
    store them in an in-memory Qdrant collection.

    args: uploaded_pdf (file): file-like object accepted by PyPDF2's PdfReader
          (e.g. a Streamlit UploadedFile).
    returns: Qdrant vector store containing the embedded chunks.
    raises: ValueError if no text could be extracted from the PDF.
    '''
    pdf_reader = PdfReader(uploaded_pdf)

    # `or ""` guards against a None/empty result for pages without a text layer.
    # Joining with "\n" keeps the last word of one page from fusing with the
    # first word of the next, and avoids quadratic `+=` string building.
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)

    text_splitter = RecursiveCharacterTextSplitter(
        # Separators are tried in order and the first one present wins, so
        # "\n\n" must come before "\n" or it is never used. The trailing ""
        # lets the splitter hard-cut separator-free runs to honour chunk_size.
        separators=["\n\n", "\n", "\r", "\t", " ", ""],
        chunk_size=1000,
        chunk_overlap=0,
    )
    text_chunks = text_splitter.split_text(text)
    if not text_chunks:
        # Fail with a clear message instead of an obscure error from Qdrant
        # when it is handed an empty list (e.g. a scanned, image-only PDF).
        raise ValueError("No extractable text found in the PDF.")

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vector_db = Qdrant.from_texts(
        text_chunks,
        embeddings,
        location=":memory:",  # Using in-memory storage
        collection_name="HarveDocs")

    return vector_db


def create_vectorstore_from_data(data):
    '''
    1. Split the documents into text chunks.
    2. Embed the chunks and store them in an in-memory Qdrant collection.
    3. Return the vector store.

    args: data -> [Document]: list of Document objects (e.g. from a web page
          or YouTube transcript loader).
    returns: Qdrant vector store containing the embedded chunks.
    raises: ValueError if the documents contain no text to index.
    '''
    text_splitter = RecursiveCharacterTextSplitter(
        # Separators are tried in order and the first one present wins, so
        # "\n\n" must come before "\n" or it is never used. The trailing ""
        # lets the splitter hard-cut separator-free runs to honour chunk_size.
        separators=["\n\n", "\n", "\r", "\t", " ", ""],
        chunk_size=1000,
        chunk_overlap=0,
    )
    text_chunks = text_splitter.split_documents(data)
    if not text_chunks:
        # Fail with a clear message instead of an obscure error from Qdrant
        # when it is handed an empty list (e.g. a page with no text content).
        raise ValueError("No text content found to index.")

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vector_db = Qdrant.from_documents(
        text_chunks,
        embeddings,
        location=":memory:",  # Using in-memory storage
        collection_name="HarveDocs")

    return vector_db


def create_context_retriever_chain(vec_store):
    '''
    Build the history-aware retriever used by the dialog chain.

    The prompt shows the chat history and the latest user input, then asks the
    LLM for a search query; create_history_aware_retriever wires that LLM
    output to vec_store's retriever.

    args: vec_store: vector store exposing as_retriever().
    returns: the retriever chain from create_history_aware_retriever.
    '''
    query_prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user", "Based on the conversation above, create a search query that you will refer to, to get information that is relevant to the conversation."),
    ])
    return create_history_aware_retriever(
        ChatOpenAI(temperature=0.1, max_tokens=500),
        vec_store.as_retriever(),
        query_prompt,
    )


def create_dialog_rag_chain(retriever_chain):
    '''
    Build the conversational RAG chain.

    Documents returned by retriever_chain are stuffed into the system
    message's {context}; the chat history and the new user input follow it.

    args: retriever_chain: retriever feeding documents into the chain (e.g. the
          one built by create_context_retriever_chain).
    returns: the chain from create_retrieval_chain; its invoke() result is
             read via the "answer" key by get_response().
    '''
    llm = ChatOpenAI(temperature=0.1, max_tokens=500)
    prompt = ChatPromptTemplate.from_messages([
        ("system",
         "Answer the user's questions based on the context below:\n{context}"),
        # The history must appear exactly once: listing the placeholder twice
        # sends every previous turn to the model twice, doubling token usage
        # and eating into the context window.
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)

    return create_retrieval_chain(retriever_chain, stuff_documents_chain)


def get_response(query):
    '''
    Answer a user message with the RAG chain over the session's vector store.

    Reads st.session_state.vec_store and st.session_state.chat_history.
    chat() calls this before appending the new turn, so the history passed
    to the chain does not yet include `query`.

    args: query (str): the user's message.
    returns: the "answer" field of the retrieval chain's response.
    '''
    retriever_chain = create_context_retriever_chain(
        st.session_state.vec_store)

    dialog_rag_chain = create_dialog_rag_chain(retriever_chain)
    response = dialog_rag_chain.invoke({
        "chat_history": st.session_state.chat_history,
        # Use the argument; reading a module-level global here only worked by
        # accident and breaks for any caller passing a different message.
        "input": query,
    })
    return response["answer"]


def chat(user_input):
    '''
    Process one chat turn, then render the full conversation.

    When user_input has non-whitespace text, get_response() answers it and the
    question/answer pair is appended to st.session_state.chat_history. Every
    AIMessage / HumanMessage in the history is then written to the page;
    other message types are skipped.

    args: user_input (str | None): the latest chat box value.
    '''
    if user_input and user_input.strip():
        answer = get_response(user_input)
        st.session_state.chat_history.extend([
            HumanMessage(content=user_input),
            AIMessage(content=answer),
        ])

    # Dialog flow
    for msg in st.session_state.chat_history:
        if isinstance(msg, AIMessage):
            role = "AI"
        elif isinstance(msg, HumanMessage):
            role = "Human"
        else:
            continue
        with st.chat_message(role):
            st.write(msg.content)


def get_chat_history():
    '''
    Return the session's chat history, creating it with a greeting on first use.

    returns: the list stored in st.session_state.chat_history.
    '''
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            AIMessage(content="Hello! How can I help you?")
        ]
    # Return on every call; previously the return sat inside the `if`, so any
    # call after the first returned None.
    return st.session_state.chat_history


# UI Config
logo = Image.open("assets/logo_harve.png")
st.set_page_config(page_title="HarveGPT", page_icon=logo, layout="wide")
st.title("HarveGPT")


# Sidebar
with st.sidebar:
    st.header("Options")
    url = st.text_input("Enter Website or YouTube URL")
    uploaded_pdf = st.file_uploader("Upload a PDF", type=["pdf"])
    # NOTE(review): the button is rendered but its value is never read.
    start_button = st.button("Start Chat")

# Normalise once: whitespace-only input counts as "no URL", and the loaders
# never receive stray leading/trailing spaces.
url = url.strip() if url else ""

# Identify the active source; a URL takes precedence over an uploaded PDF.
# NOTE(review): a PDF is identified by name + size; two different files that
# share both would be treated as the same source.
if url:
    source_key = ("url", url)
elif uploaded_pdf is not None:
    source_key = ("pdf", uploaded_pdf.name, uploaded_pdf.size)
else:
    source_key = None

if source_key is None:
    st.success("👈  Please provide Harve with a source to start the chat.")

else:
    # Streamlit reruns this script on every interaction. Rebuild the vector
    # store only when the source differs from the one it was built from, so a
    # newly entered URL/PDF actually replaces the old store, and pages or
    # transcripts are not re-downloaded on every chat message.
    if st.session_state.get("vec_store_source") != source_key:
        if url:
            try:
                if "youtube.com" in url or "youtu.be" in url:
                    data = extract_transcript_from_youtube_url(url)
                else:
                    data = extract_data_from_url(url)
            except Exception as e:
                st.warning(
                    f"An error occurred: {e} Enter a valid link to continue.")
                st.stop()
            st.session_state.vec_store = create_vectorstore_from_data(data)
        else:
            st.session_state.vec_store = create_vectorstore_from_pdf(
                uploaded_pdf)
        st.session_state.vec_store_source = source_key

    # Keep the conversation in st.session_state so it survives reruns.
    get_chat_history()

    # Chat input
    user_input = st.chat_input("Type a message...")
    chat(user_input)