nikhmr1235 committed (verified)
Commit f5dabf5 · 1 Parent(s): 965a01e

fix issue of seeing context from previously uploaded doc due to gradio state being shared incorrectly

Files changed (1):
  1. app.py (+76, -88)
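The commit message above describes the bug class this change fixes: any object created at module level in a Gradio app is shared by every connected browser session, so a vector store built for one visitor's PDF can answer another visitor's questions. Below is a minimal sketch of that difference, not taken from this repository; the toy `shared` dict and the functions `set_doc_shared` / `set_doc_per_session` are made up purely to contrast module-level state with `gr.State`, which Gradio keeps per session.

import gradio as gr

# Anti-pattern (roughly what the old module-level PDFChatbot did): one object
# is shared by every visitor, so user B can "see" user A's document.
shared = {"doc": None}

def set_doc_shared(name):
    shared["doc"] = name
    return f"shared store now holds: {shared['doc']}"

# Fix: keep the value in gr.State, which Gradio stores once per browser session.
# The updated value is returned and routed back into the State component.
def set_doc_per_session(name, doc_state):
    doc_state = name
    return f"this session's store now holds: {doc_state}", doc_state

with gr.Blocks() as demo:
    doc_state = gr.State()  # one value per connected session
    name = gr.Textbox(label="Document name")
    status = gr.Textbox(label="Status")
    gr.Button("Load (shared, buggy)").click(set_doc_shared, [name], [status])
    gr.Button("Load (per-session)").click(set_doc_per_session, [name, doc_state], [status, doc_state])

demo.launch()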
app.py CHANGED
@@ -22,101 +22,86 @@ LLM_MODEL = "gemini-1.5-flash"
  EMBEDDING_MODEL = "models/embedding-001"
  CHROMA_DB_PATH = tempfile.gettempdir() + "/chroma_db"
 
- class PDFChatbot:
-     def __init__(self):
-         self.state = SessionState()
-
-     async def process_pdf(self, pdf_file):
-         try:
-             if self.state.is_db_ready():
-                 print("Database is already ready.")
-                 return
-
-             file_size_mb = os.path.getsize(pdf_file.name) / (1024 * 1024)
-             if file_size_mb >= 75:
-                 print("File size exceeds the 75 MB limit.")
-                 gr.Error("File size exceeds the 75 MB limit. Please upload a smaller PDF.")
-                 return
-
-             self.state = SessionState()
-             print("Opening PDF file...")
-             try:
-                 doc = fitz.open(pdf_file.name)
-                 text = ""
-                 for page in doc:
-                     text += page.get_text()
-                 doc.close()
-             except Exception as e:
-                 print(f"Error processing PDF document: {str(e)}")
-                 return
-
-             print("PDF file opened successfully. Splitting text into chunks...")
-             text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-             docs = text_splitter.create_documents([text])
-             print("Text split into chunks successfully.")
-
-             embeddings = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, google_api_key=google_api_key)
-             self.state.db = await Chroma.afrom_documents(
-                 documents=docs,
-                 embedding=embeddings,
-                 persist_directory=self.state.vector_store_path,
-                 collection_name=self.state.session_id
-             )
-             print("PDF processed successfully! Database is ready.")
-         except Exception as e:
-             if os.path.exists(self.state.vector_store_path):
-                 shutil.rmtree(self.state.vector_store_path)
-             print(f"An error occurred: {str(e)}")
-
-     def is_db_ready(self):
-         return self.state.db is not None
-
-     async def chat_with_pdf(self, message, history):
-         print("Chat interface called. Checking if database is ready...")
-         if not self.is_db_ready():
-             print("Database is not ready.")
-             yield "Error: Database not ready."
-             return
-
-         print("Database is ready. Retrieving relevant documents...")
-         retriever = self.state.db.as_retriever()
-         llm = ChatGoogleGenerativeAI(model=LLM_MODEL, temperature=0.7, google_api_key=google_api_key)
-
-         prompt_template = PromptTemplate(
-             template="""
-             You are a helpful assistant for a PDF document.
-             Answer the user's question based on the following context.
-             If you don't know the answer, just say that you don't know, don't try to make up an answer.
-             ----------------
-             Context: {context}
-             Question: {question}
-             """,
-             input_variables=["context", "question"],
-         )
-
-         rag_chain = (
-             {"context": retriever, "question": RunnablePassthrough()}
-             | prompt_template
-             | llm
-             | StrOutputParser()
-         )
-
-         response = await rag_chain.ainvoke(
-             message
-         )
-         yield response
-
-
- class SessionState:
-     def __init__(self):
-         self.session_id = str(uuid.uuid4())
-         self.db = None
-         self.vector_store_path = os.path.join(CHROMA_DB_PATH, self.session_id)
-
-     def is_db_ready(self):
-         return self.db is not None
+ class SessionState:
+     def __init__(self):
+         self.session_id = str(uuid.uuid4())
+         self.db = None
+         self.vector_store_path = os.path.join(CHROMA_DB_PATH, self.session_id)
+
+     def is_db_ready(self):
+         return self.db is not None
+
+ async def process_pdf(pdf_file, state: SessionState):
+     try:
+         file_size_mb = os.path.getsize(pdf_file.name) / (1024 * 1024)
+         if file_size_mb >= 75:
+             gr.Error("File size exceeds the 75 MB limit. Please upload a smaller PDF.")
+             return
+
+         print("Opening PDF file...")
+         try:
+             doc = fitz.open(pdf_file.name)
+             text = ""
+             for page in doc:
+                 text += page.get_text()
+             doc.close()
+         except Exception as e:
+             print(f"Error processing PDF document: {str(e)}")
+             return
+
+         print("PDF file opened successfully. Splitting text into chunks...")
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+         docs = text_splitter.create_documents([text])
+         print("Text split into chunks successfully.")
+
+         embeddings = GoogleGenerativeAIEmbeddings(model=EMBEDDING_MODEL, google_api_key=google_api_key)
+         state.db = await Chroma.afrom_documents(
+             documents=docs,
+             embedding=embeddings,
+             persist_directory=state.vector_store_path,
+             collection_name=state.session_id
+         )
+         print("PDF processed successfully! Database is ready.")
+     except Exception as e:
+         if os.path.exists(state.vector_store_path):
+             shutil.rmtree(state.vector_store_path)
+         print(f"An error occurred: {str(e)}")
+
+ async def chat_with_pdf(message, history, state: SessionState):
+     print("Chat interface called. Checking if database is ready...")
+     if not state or not state.is_db_ready():
+         print("Database is not ready.")
+         yield "Error: Database not ready. Please upload a PDF first."
+         return
+
+     print("Database is ready. Retrieving relevant documents...")
+     retriever = state.db.as_retriever()
+     llm = ChatGoogleGenerativeAI(model=LLM_MODEL, temperature=0.7, google_api_key=google_api_key)
+
+     prompt_template = PromptTemplate(
+         template="""
+         You are a helpful assistant for a PDF document.
+         Answer the user's question based on the following context.
+         If you don't know the answer, just say that you don't know, don't try to make up an answer.
+         ----------------
+         Context: {context}
+         Question: {question}
+         """,
+         input_variables=["context", "question"],
+     )
+
+     rag_chain = (
+         {"context": retriever, "question": RunnablePassthrough()}
+         | prompt_template
+         | llm
+         | StrOutputParser()
+     )
+
+     response = await rag_chain.ainvoke(message)
+     yield response
 
  with gr.Blocks(title="PDF Chatbot") as demo:
-     chatbot = PDFChatbot()
+     state = gr.State()
 
      gr.Markdown(
          """
@@ -134,22 +119,25 @@ with gr.Blocks(title="PDF Chatbot") as demo:
 
      with gr.Row(visible=False) as chat_row:
          chat_interface = gr.ChatInterface(
-             fn=chatbot.chat_with_pdf,
+             fn=chat_with_pdf,
+             additional_inputs=[state],
              chatbot=gr.Chatbot(type="messages"),
              textbox=gr.Textbox(placeholder="Type your question here...", scale=7),
              examples=[["What is the main topic of the document?"], ["Summarize the key findings."], ["Who are the authors?"]],
              title="Chat Interface",
-             theme="soft"
+             theme="soft",
+             type="messages"
          )
 
      async def process_and_show_chat(file):
-         await chatbot.process_pdf(file)
-         return gr.update(visible=True), gr.update(interactive=False)
+         new_state = SessionState()
+         await process_pdf(file, new_state)
+         return gr.update(visible=True), gr.update(interactive=False), new_state
 
      file_upload_input.upload(
          fn=process_and_show_chat,
          inputs=[file_upload_input],
-         outputs=[chat_row, file_upload_input]
+         outputs=[chat_row, file_upload_input, state]
      )
 
  demo.launch()
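The isolation in the new version comes from SessionState itself: each upload builds a fresh object, so each session gets its own Chroma collection name and persist directory. A quick sanity check along those lines is sketched below; it is hypothetical, not part of the commit, and simply re-declares a class with the same shape as the SessionState added in app.py.

import os
import tempfile
import uuid

CHROMA_DB_PATH = tempfile.gettempdir() + "/chroma_db"

# Same shape as the SessionState class added in app.py above.
class SessionState:
    def __init__(self):
        self.session_id = str(uuid.uuid4())
        self.db = None
        self.vector_store_path = os.path.join(CHROMA_DB_PATH, self.session_id)

    def is_db_ready(self):
        return self.db is not None

a, b = SessionState(), SessionState()
assert a.session_id != b.session_id                # distinct Chroma collection names
assert a.vector_store_path != b.vector_store_path  # distinct on-disk vector stores
print("each session gets its own collection and store")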