import sys
import os
import re
import streamlit as st
import time
sys.path.append(os.path.abspath("."))
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import OpenAI
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import NLTKTextSplitter
from patent_downloader import PatentDownloader
PERSISTED_DIRECTORY = "."
# Fetch API key securely from the environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()
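# Parse the patent PDF with Unstructured and split it into ~1,000-character chunks for indexing.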
def load_docs(document_path):
    loader = UnstructuredPDFLoader(document_path)
    documents = loader.load()
    text_splitter = NLTKTextSplitter(chunk_size=1000)
    return text_splitter.split_documents(documents)
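# Check whether this file is already present in the Chroma collection's metadata.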
def already_indexed(vectordb, file_name):
    indexed_sources = set(
        x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
    )
    return file_name in indexed_sources
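# Build (or reuse) the vector store for this patent and wrap it in a ConversationalRetrievalChain.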
def load_chain(file_name=None):
    loaded_patent = st.session_state.get("LOADED_PATENT")
    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=HuggingFaceEmbeddings(),
    )
    if loaded_patent == file_name or already_indexed(vectordb, file_name):
        st.write("Already indexed")
    else:
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("Length: ", len(docs))
        vectordb = Chroma.from_documents(
            docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
        st.session_state["LOADED_PATENT"] = file_name
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
    )
    return ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
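# Pull the patent identifier (e.g. US1234567) out of a Google Patents URL.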
def extract_patent_number(url):
    pattern = r"/patent/([A-Z]{2}\d+)"
    match = re.search(pattern, url)
    return match.group(1) if match else None
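# Download the patent PDF via PatentDownloader and return the local file name.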
def download_pdf(patent_number):
    patent_downloader = PatentDownloader()
    patent_downloader.download(patent=patent_number)
    return f"{patent_number}.pdf"
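# Streamlit entry point: page setup, patent download, indexing, and the chat loop.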
if __name__ == "__main__":
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="π",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("π Patent Chat: Google Patents Chat Demo")
    # Allow the user to input the Google Patents link
    patent_link = st.text_input("Enter Google Patent Link:", key="PATENT_LINK")
    if not patent_link:
        st.warning("Please enter a Google patent link to proceed.")
        st.stop()
    else:
        st.session_state["patent_link_configured"] = True
    patent_number = extract_patent_number(patent_link)
    if not patent_number:
        st.error("Invalid patent link format. Please provide a valid Google patent link.")
        st.stop()
    st.write("Patent number: ", patent_number)
    pdf_path = f"{patent_number}.pdf"
    if os.path.isfile(pdf_path):
        st.write("File already downloaded.")
    else:
        st.write("Downloading patent file...")
        pdf_path = download_pdf(patent_number)
        st.write("File downloaded.")
    chain = load_chain(pdf_path)
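    # Seed the chat history with an assistant greeting on first load, then replay past messages.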
if "messages" not in st.session_state:
st.session_state["messages"] = [
{"role": "assistant", "content": "How can I help you?"}
]
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
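    # Handle a new question: record it, run the retrieval chain, and stream the answer word by word.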
    if user_input := st.chat_input("What is your question?"):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            with st.spinner("CHAT-BOT is at Work ..."):
                assistant_response = chain({"question": user_input})
            for chunk in assistant_response["answer"].split():
                full_response += chunk + " "
                time.sleep(0.05)
                message_placeholder.markdown(full_response + "▌")
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )