Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| import re | |
| import streamlit as st | |
| import time | |
| sys.path.append(os.path.abspath(".")) | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.llms import OpenAI | |
| from langchain.document_loaders import UnstructuredPDFLoader | |
| from langchain.vectorstores import Chroma | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.text_splitter import NLTKTextSplitter | |
| from patent_downloader import PatentDownloader | |
# Chroma persists its index here; "." resolves against the current working directory.
PERSISTED_DIRECTORY = "."

# The OpenAI key is taken from the environment instead of being hard-coded.
# Nothing downstream can run without it, so report the problem and end this
# script run right away.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY in (None, ""):
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()
def load_docs(document_path):
    """Parse a PDF from disk and re-split it into sentence-aware chunks.

    Args:
        document_path: Filesystem path of the PDF to read.

    Returns:
        list: Documents produced by ``UnstructuredPDFLoader`` and then split by
        ``NLTKTextSplitter`` with ``chunk_size=1000``.
    """
    raw_documents = UnstructuredPDFLoader(document_path).load()
    splitter = NLTKTextSplitter(chunk_size=1000)
    return splitter.split_documents(raw_documents)
def already_indexed(vectordb, file_name):
    """Return whether any chunk stored in *vectordb* came from *file_name*.

    Args:
        vectordb: The vector store to inspect. Only its
            ``get(include=["metadatas"])`` method is used, and the result is
            read as a mapping with a ``"metadatas"`` list.
        file_name: Value compared against each stored chunk's ``"source"``
            metadata (presumably the path the PDF loader recorded; confirm).

    Returns:
        bool: True if at least one stored chunk's ``"source"`` equals
        *file_name*. Entries with no metadata or no ``"source"`` key are
        ignored rather than raising.
    """
    result = vectordb.get(include=["metadatas"])
    # Metadata entries may be None (chunks stored without metadata) or lack a
    # "source" key, and the list itself may be missing/None on an empty store.
    # Skip such entries instead of crashing with TypeError/KeyError.
    metadatas = result.get("metadatas") or []
    indexed_sources = {
        metadata["source"]
        for metadata in metadatas
        if metadata and "source" in metadata
    }
    return file_name in indexed_sources
def load_chain(file_name=None):
    """Build a conversational retrieval chain over the PDF at *file_name*.

    The PDF is indexed into the Chroma store under ``PERSISTED_DIRECTORY``
    unless it was the last patent loaded in this session or its chunks are
    already stored. Otherwise the existing collection is deleted and rebuilt
    from this PDF alone.

    Args:
        file_name: Path of the PDF to chat about. It is also compared against
            the stored chunks' ``"source"`` metadata to detect an existing index.

    Returns:
        ConversationalRetrievalChain: An OpenAI LLM (temperature 0) combined with
        a top-3 retriever over the store. Conversation memory is kept in
        ``st.session_state`` so it survives Streamlit reruns and is reset when
        *file_name* changes.
    """
    # Share one embeddings object for opening and (re)building the store
    # instead of constructing the embedding model twice.
    embeddings = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=embeddings,
    )

    loaded_patent = st.session_state.get("LOADED_PATENT")
    if loaded_patent == file_name or already_indexed(vectordb, file_name):
        st.write("Already indexed")
    else:
        # Drop the previous collection so retrieval only sees this document.
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("Length: ", len(docs))
        vectordb = Chroma.from_documents(
            docs, embeddings, persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
    st.session_state["LOADED_PATENT"] = file_name

    # Streamlit re-executes the whole script on every interaction. A memory
    # object created fresh on each call would forget all earlier turns, so keep
    # one in session state, and start a new one when the patent changes.
    if (
        "CHAT_MEMORY" not in st.session_state
        or st.session_state.get("CHAT_MEMORY_PATENT") != file_name
    ):
        st.session_state["CHAT_MEMORY"] = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            input_key="question",
            output_key="answer",
        )
        st.session_state["CHAT_MEMORY_PATENT"] = file_name
    memory = st.session_state["CHAT_MEMORY"]

    return ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
def extract_patent_number(url):
    """Pull the publication number out of a Google Patents URL.

    The number is a two-letter uppercase country code followed by digits,
    taken from the first ``/patent/`` path segment. Any trailing kind code
    (e.g. ``B2``) is not included.

    Args:
        url: The link text to search.

    Returns:
        str | None: The matched number, e.g. ``"US10000000"``, or None when
        the URL contains no such segment.

    >>> extract_patent_number("https://patents.google.com/patent/US10000000B2/en")
    'US10000000'
    """
    found = re.search(r"/patent/([A-Z]{2}\d+)", url)
    if found is None:
        return None
    return found.group(1)
def download_pdf(patent_number):
    """Download the PDF for *patent_number* using ``PatentDownloader``.

    Args:
        patent_number: Publication number to fetch, e.g. ``"US10000000"``.

    Returns:
        str: The path ``"<patent_number>.pdf"``. NOTE(review): this path is
        not checked here. It assumes PatentDownloader writes the file under
        that name in the current working directory; confirm.
    """
    PatentDownloader().download(patent=patent_number)
    return "{}.pdf".format(patent_number)
if __name__ == "__main__":
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        # NOTE(review): "π" here and in the header looks like a mis-encoded
        # emoji; confirm and restore the intended icon.
        page_icon="π",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("π Patent Chat: Google Patents Chat Demo")

    # Allow user to input the Google patent link
    patent_link = st.text_input("Enter Google Patent Link:", key="PATENT_LINK")
    if not patent_link:
        st.warning("Please enter a Google patent link to proceed.")
        st.stop()
    else:
        st.session_state["patent_link_configured"] = True

    patent_number = extract_patent_number(patent_link)
    if not patent_number:
        st.error("Invalid patent link format. Please provide a valid Google patent link.")
        st.stop()
    st.write("Patent number: ", patent_number)

    pdf_path = f"{patent_number}.pdf"
    if os.path.isfile(pdf_path):
        st.write("File already downloaded.")
    else:
        st.write("Downloading patent file...")
        pdf_path = download_pdf(patent_number)
        # download_pdf only returns the expected path. Verify the file really
        # exists before claiming success and handing it to the PDF loader.
        if not os.path.isfile(pdf_path):
            st.error(f"Could not download the PDF for patent {patent_number}.")
            st.stop()
        st.write("File downloaded.")

    chain = load_chain(pdf_path)

    # Session state outlives a change of patent link; start a fresh transcript
    # when the user switches patents so old answers are not shown.
    if (
        "messages" not in st.session_state
        or st.session_state.get("MESSAGES_PATENT") != patent_number
    ):
        st.session_state["messages"] = [
            {"role": "assistant", "content": "How can I help you?"}
        ]
        st.session_state["MESSAGES_PATENT"] = patent_number

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if user_input := st.chat_input("What is your question?"):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            with st.spinner("CHAT-BOT is at Work ..."):
                assistant_response = chain({"question": user_input})
            answer = assistant_response["answer"]

            # Reveal the answer word by word for a typing effect. Splitting on a
            # capturing group keeps the original whitespace (including newlines),
            # so Markdown lists and paragraphs render correctly.
            full_response = ""
            for piece in re.split(r"(\s+)", answer):
                full_response += piece
                if piece.strip():
                    time.sleep(0.05)
                    # "▌" is a temporary typing cursor (the literal was mis-encoded as "β").
                    message_placeholder.markdown(full_response + "▌")
            # Final render without the cursor.
            message_placeholder.markdown(full_response)

        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )