File size: 4,771 Bytes
4e7dff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import sys
import os
import re
import streamlit as st
import time

sys.path.append(os.path.abspath("."))
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import OpenAI
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import NLTKTextSplitter
from patent_downloader import PatentDownloader

# Directory handed to Chroma as ``persist_directory`` for the vector index.
PERSISTED_DIRECTORY = "."

# The OpenAI key comes from the environment, never from source. Without it the
# LLM cannot be constructed, so the app shows an error and halts the script run.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY is None or OPENAI_API_KEY == "":
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()

def load_docs(document_path, chunk_size=1000):
    """Load a PDF file and split its text into chunks for indexing.

    Args:
        document_path: Path of the PDF file handed to ``UnstructuredPDFLoader``.
        chunk_size: Chunk size passed to ``NLTKTextSplitter``. Defaults to
            1000, the value that used to be hard-coded.

    Returns:
        The chunks produced by ``NLTKTextSplitter.split_documents``.
    """
    raw_documents = UnstructuredPDFLoader(document_path).load()
    splitter = NLTKTextSplitter(chunk_size=chunk_size)
    return splitter.split_documents(raw_documents)

def already_indexed(vectordb, file_name):
    """Return True if any record in *vectordb* has ``source`` metadata equal to *file_name*.

    Reads every stored record's metadata via ``vectordb.get(include=["metadatas"])``.
    Records whose metadata is empty/``None`` or lacks a ``"source"`` key are
    skipped rather than raising ``TypeError``/``KeyError``.

    Args:
        vectordb: Vector store exposing ``get(include=...)`` that returns a
            mapping with a ``"metadatas"`` list (as Chroma does).
        file_name: Source path to look for.

    Returns:
        ``True`` when *file_name* appears as a ``source``, else ``False``.
    """
    metadatas = vectordb.get(include=["metadatas"]).get("metadatas") or []
    indexed_sources = {
        metadata["source"]
        for metadata in metadatas
        if metadata and "source" in metadata
    }
    return file_name in indexed_sources

def load_chain(file_name=None):
    """Return a conversational retrieval chain over the PDF at *file_name*.

    Streamlit reruns this script on every chat message. Rebuilding the chain
    each time created a new ``ConversationBufferMemory`` and so discarded the
    chat history every turn. The chain is therefore cached in
    ``st.session_state`` and reused while *file_name* is unchanged.

    When a new chain is needed, the persisted Chroma collection is reused if
    *file_name* matches ``st.session_state["LOADED_PATENT"]`` or is already
    indexed. Otherwise the collection is deleted and rebuilt from the PDF.

    Args:
        file_name: Path of the patent PDF to chat with.

    Returns:
        A ``ConversationalRetrievalChain`` backed by an OpenAI LLM, a
        top-3 retriever over the collection, and buffer memory.
    """
    cached_chain = st.session_state.get("CHAIN")
    if cached_chain is not None and st.session_state.get("CHAIN_FILE") == file_name:
        return cached_chain

    # One embedding model instance serves both the lookup and any rebuild.
    embeddings = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=embeddings,
    )

    loaded_patent = st.session_state.get("LOADED_PATENT")
    if loaded_patent == file_name or already_indexed(vectordb, file_name):
        st.write("Already indexed")
    else:
        # Drops the whole collection: only one patent is indexed at a time.
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("Length: ", len(docs))

        vectordb = Chroma.from_documents(
            docs, embeddings, persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
        st.session_state["LOADED_PATENT"] = file_name

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
    )
    chain = ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
    st.session_state["CHAIN"] = chain
    st.session_state["CHAIN_FILE"] = file_name
    return chain

def extract_patent_number(url):
    """Pull the patent identifier out of a Google Patents URL.

    Matches two uppercase letters followed by digits right after ``/patent/``,
    e.g. ``.../patent/US10123456B2/en`` yields ``"US10123456"``. Any trailing
    kind code is not captured.

    Returns:
        The identifier string, or ``None`` when the URL does not match.
    """
    found = re.search(r"/patent/([A-Z]{2}\d+)", url)
    if found is None:
        return None
    return found.group(1)

def download_pdf(patent_number):
    """Ask ``PatentDownloader`` to fetch *patent_number* and return the PDF name.

    The returned name is ``"<patent_number>.pdf"``. This function does not
    check that the downloader actually wrote a file with that name.
    # NOTE(review): assumes PatentDownloader saves to the working directory
    # under that name; confirm against the patent_downloader package.
    """
    PatentDownloader().download(patent=patent_number)
    pdf_name = f"{patent_number}.pdf"
    return pdf_name

if __name__ == "__main__":
    # Streamlit entry point: collect a Google Patents link, make sure the PDF
    # is available locally, build the retrieval chain, then run the chat UI.
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="📖",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("📖 Patent Chat: Google Patents Chat Demo")

    # Allow user to input the Google patent link
    patent_link = st.text_input("Enter Google Patent Link:", key="PATENT_LINK")

    if not patent_link:
        st.warning("Please enter a Google patent link to proceed.")
        st.stop()
    st.session_state["patent_link_configured"] = True

    patent_number = extract_patent_number(patent_link)
    if not patent_number:
        st.error("Invalid patent link format. Please provide a valid Google patent link.")
        st.stop()

    st.write("Patent number: ", patent_number)

    pdf_path = f"{patent_number}.pdf"
    if os.path.isfile(pdf_path):
        st.write("File already downloaded.")
    else:
        st.write("Downloading patent file...")
        try:
            pdf_path = download_pdf(patent_number)
        except Exception as err:  # UI boundary: show any downloader failure to the user
            st.error(f"Failed to download patent {patent_number}: {err}")
            st.stop()
        # download_pdf only reports the expected filename, so confirm it exists.
        if not os.path.isfile(pdf_path):
            st.error(f"Download failed: {pdf_path} was not created.")
            st.stop()
        st.write("File downloaded.")

    chain = load_chain(pdf_path)

    # Start a fresh conversation whenever a different patent is loaded, so the
    # history on screen always belongs to the current document.
    if (
        "messages" not in st.session_state
        or st.session_state.get("MESSAGES_PATENT") != patent_number
    ):
        st.session_state["messages"] = [
            {"role": "assistant", "content": "How can I help you?"}
        ]
        st.session_state["MESSAGES_PATENT"] = patent_number

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if user_input := st.chat_input("What is your question?"):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        # Spinner and streamed text both live inside the assistant bubble.
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            with st.spinner("CHAT-BOT is at Work ..."):
                answer = chain({"question": user_input})["answer"].strip()

            # Reveal the answer word by word. Each token keeps its trailing
            # whitespace so markdown line breaks survive; str.split() used to
            # collapse them.
            full_response = ""
            for chunk in re.findall(r"\S+\s*", answer):
                full_response += chunk
                time.sleep(0.05)
                message_placeholder.markdown(full_response + "▌")
            # Final render without the typing cursor.
            message_placeholder.markdown(answer)
        st.session_state.messages.append(
            {"role": "assistant", "content": answer}
        )