manabb committed on
Commit
e4e5c5c
·
verified ·
1 Parent(s): 8eb71cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -253
app.py CHANGED
@@ -1,266 +1,87 @@
1
-
 
2
  import gradio as gr
3
- from langchain.document_loaders import PyPDFLoader, DirectoryLoader
4
- from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain.embeddings import HuggingFaceEmbeddings
6
  from langchain.vectorstores import FAISS
7
- from langchain.llms import HuggingFaceHub
 
 
8
  from langchain.chains import RetrievalQA
9
- from langchain.prompts import PromptTemplate
10
- import os
11
- import tempfile
12
- import datetime
13
 
14
- class EnhancedPDFChatbot:
15
- def __init__(self):
16
- self.vectorstore = None
17
- self.qa_chain = None
18
- self.embeddings = HuggingFaceEmbeddings()
19
- self.is_ready = False
20
- self.chat_history = []
21
-
22
- def process_pdf(self, pdf_file):
23
- """Process uploaded PDF file with enhanced error handling"""
24
- try:
25
- if pdf_file is None:
26
- return "Please select a PDF file first!"
27
-
28
- # Save uploaded file
29
- with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
30
- tmp_file.write(pdf_file)
31
- tmp_path = tmp_file.name
32
-
33
- # Load and process PDF
34
- loader = PyPDFLoader(tmp_path)
35
- documents = loader.load()
36
-
37
- # Clean up
38
- os.unlink(tmp_path)
39
-
40
- if not documents:
41
- return "No content could be extracted from the PDF."
42
-
43
- # Split text
44
- text_splitter = RecursiveCharacterTextSplitter(
45
- chunk_size=800,
46
- chunk_overlap=150,
47
- length_function=len,
48
- )
49
-
50
- chunks = text_splitter.split_documents(documents)
51
-
52
- # Create vector store
53
- self.vectorstore = FAISS.from_documents(chunks, self.embeddings)
54
- self.setup_qa_chain()
55
-
56
- self.is_ready = True
57
- self.chat_history = []
58
-
59
- return f"βœ… Success! Processed {len(documents)} pages into {len(chunks)} chunks. You can now ask questions!"
60
-
61
- except Exception as e:
62
- return f"❌ Error: {str(e)}"
63
-
64
- def setup_qa_chain(self):
65
- """Setup QA chain with enhanced prompt"""
66
- llm = HuggingFaceHub(
67
- repo_id="google/flan-t5-small",
68
- model_kwargs={"temperature": 0.2, "max_length": 512, "repetition_penalty": 1.1}
69
- )
70
-
71
- prompt_template = """As an AI assistant, provide accurate answers based on the given context.
72
 
73
- CONTEXT:
74
- {context}
75
 
76
- QUESTION:
77
- {question}
 
 
78
 
79
- INSTRUCTIONS:
80
- - Answer clearly and concisely
81
- - Base your answer strictly on the context provided
82
- - If the answer isn't in the context, say "I cannot find this information in the document"
83
- - Use bullet points for lists when appropriate
84
- - Be helpful and professional
85
 
86
- ANSWER:
87
- """
88
-
89
- PROMPT = PromptTemplate(
90
- template=prompt_template,
91
- input_variables=["context", "question"]
92
- )
93
-
94
- self.qa_chain = RetrievalQA.from_chain_type(
95
- llm=llm,
96
- chain_type="stuff",
97
- retriever=self.vectorstore.as_retriever(
98
- search_type="similarity",
99
- search_kwargs={"k": 4}
100
- ),
101
- chain_type_kwargs={"prompt": PROMPT},
102
- return_source_documents=True
103
- )
104
-
105
- def ask_question(self, question, history):
106
- """Ask question with enhanced response formatting"""
107
- if not self.is_ready:
108
- return "Please upload and process a PDF first!", history
109
-
110
- if not question.strip():
111
- return "", history
112
-
113
- try:
114
- # Add timestamp
115
- timestamp = datetime.datetime.now().strftime("%H:%M:%S")
116
-
117
- result = self.qa_chain({"query": question})
118
- answer = result["result"]
119
-
120
- # Format response
121
- formatted_response = f"**{timestamp}**\n\n{answer}\n\n---\n**Sources:**"
122
-
123
- for i, doc in enumerate(result["source_documents"][:3]):
124
- page_num = doc.metadata.get('page', 'N/A') + 1 # Convert to 1-indexed
125
- content = doc.page_content.replace('\n', ' ').strip()
126
- preview = content[:120] + "..." if len(content) > 120 else content
127
- formatted_response += f"\nβ€’ Page {page_num}: {preview}"
128
-
129
- # Update history
130
- history.append((question, formatted_response))
131
- self.chat_history = history
132
-
133
- return "", history
134
-
135
- except Exception as e:
136
- error_msg = f"Error processing your question: {str(e)}"
137
- history.append((question, error_msg))
138
- return "", history
139
-
140
- def clear_chat(self):
141
- """Clear chat history"""
142
- self.chat_history = []
143
- return []
144
 
145
- # Create enhanced chatbot
146
- enhanced_chatbot = EnhancedPDFChatbot()
 
147
 
148
- # Create enhanced Gradio interface
149
- with gr.Blocks(title="Enhanced PDF Chatbot", theme=gr.themes.Default()) as enhanced_demo:
150
- gr.Markdown("""
151
- # πŸš€ Enhanced PDF Chatbot Agent
152
- **Upload a PDF document and have a conversation with AI about its content!**
153
- """)
154
-
155
- with gr.Row():
156
- with gr.Column(scale=1):
157
- with gr.Group():
158
- gr.Markdown("### πŸ“„ Document Upload")
159
- pdf_input = gr.File(
160
- label="Upload PDF File",
161
- file_types=[".pdf"],
162
- type="binary"
163
- )
164
- upload_btn = gr.Button("Process Document", variant="primary")
165
- status_output = gr.Textbox(label="Status", interactive=False)
166
-
167
- with gr.Group():
168
- gr.Markdown("### βš™οΈ Settings")
169
- chunk_size = gr.Slider(
170
- minimum=500,
171
- maximum=2000,
172
- value=800,
173
- step=100,
174
- label="Chunk Size"
175
- )
176
- temperature = gr.Slider(
177
- minimum=0.1,
178
- maximum=1.0,
179
- value=0.2,
180
- step=0.1,
181
- label="Temperature"
182
- )
183
-
184
- with gr.Column(scale=2):
185
- gr.Markdown("### πŸ’¬ Chat Interface")
186
- chatbot = gr.Chatbot(height=450, show_copy_button=True)
187
-
188
- with gr.Row():
189
- question_box = gr.Textbox(
190
- placeholder="Ask a question about the PDF...",
191
- label="Your Question",
192
- scale=4
193
- )
194
- ask_btn = gr.Button("Ask", scale=1)
195
-
196
- with gr.Row():
197
- clear_btn = gr.Button("Clear Chat", variant="secondary")
198
- export_btn = gr.Button("Export Chat", variant="secondary")
199
-
200
- # Examples
201
- gr.Examples(
202
- examples=[
203
- "What is the main purpose of this document?",
204
- "Summarize the key points in bullet form",
205
- "What are the main findings or conclusions?",
206
- "List any recommendations mentioned"
207
- ],
208
- inputs=question_box,
209
- label="Example Questions"
210
- )
211
-
212
- # Event handlers
213
- upload_btn.click(
214
- fn=enhanced_chatbot.process_pdf,
215
- inputs=pdf_input,
216
- outputs=status_output
217
- )
218
-
219
- def ask_question_wrapper(question, history):
220
- return enhanced_chatbot.ask_question(question, history)
221
-
222
- ask_btn.click(
223
- fn=ask_question_wrapper,
224
- inputs=[question_box, chatbot],
225
- outputs=[question_box, chatbot]
226
- )
227
-
228
- question_box.submit(
229
- fn=ask_question_wrapper,
230
- inputs=[question_box, chatbot],
231
- outputs=[question_box, chatbot]
232
- )
233
-
234
- clear_btn.click(
235
- fn=enhanced_chatbot.clear_chat,
236
- inputs=[],
237
- outputs=chatbot
238
- )
239
-
240
- # Export functionality
241
- def export_chat():
242
- if not enhanced_chatbot.chat_history:
243
- return "No chat history to export!"
244
-
245
- export_text = "PDF Chatbot Conversation Export\n"
246
- export_text += "=" * 40 + "\n\n"
247
-
248
- for i, (question, answer) in enumerate(enhanced_chatbot.chat_history, 1):
249
- export_text += f"Q{i}: {question}\n"
250
- export_text += f"A{i}: {answer}\n"
251
- export_text += "-" * 30 + "\n"
252
-
253
- return export_text
254
-
255
- export_btn.click(
256
- fn=export_chat,
257
- inputs=[],
258
- outputs=gr.Textbox(label="Exported Chat", lines=20)
259
  )
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  if __name__ == "__main__":
262
- enhanced_demo.launch(
263
- server_name="0.0.0.0",
264
- server_port=7860,
265
- share=True
266
- )
 
1
+ # app.py
2
+ import os
3
  import gradio as gr
4
+
 
 
5
  from langchain.vectorstores import FAISS
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.document_loaders import TextLoader
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
  from langchain.chains import RetrievalQA
10
+ from langchain.llms import HuggingFacePipeline
11
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
 
12
 
13
+ # Optional: Set HF Token if needed
14
+ # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'hf_XXXX'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
# Initialize embedding model
# Sentence-transformers MiniLM: a small, CPU-friendly embedder used to
# vectorize document chunks for the FAISS index built in process_file.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Load HF model (lightweight for CPU)
# flan-t5-small is an encoder-decoder (seq2seq) model, hence
# AutoModelForSeq2SeqLM and the "text2text-generation" task below.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Wrap in pipeline
# max_length=512 caps the generated answer length; the LangChain
# HuggingFacePipeline wrapper makes the pipeline usable as an LLM
# inside RetrievalQA.
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
llm = HuggingFacePipeline(pipeline=pipe)
 
 
 
27
 
28
def process_file(file_path, chunk_size=1000, chunk_overlap=200):
    """Build a RetrievalQA chain over the plain-text file at *file_path*.

    Args:
        file_path: Path to a .txt document on disk.
        chunk_size: Maximum characters per chunk (default 1000).
        chunk_overlap: Characters shared between consecutive chunks
            (default 200) so context isn't lost at chunk boundaries.

    Returns:
        A RetrievalQA chain backed by an in-memory FAISS index of the
        document, using the module-level ``llm`` and ``embedding_model``.
    """
    # Load & split document into overlapping chunks small enough to fit
    # the model's context window.
    documents = TextLoader(file_path).load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    docs = text_splitter.split_documents(documents)

    # Create vector DB and a retriever over it.
    vector_db = FAISS.from_documents(docs, embedding_model)
    retriever = vector_db.as_retriever()

    # Setup RetrievalQA chain. "stuff" = concatenate all retrieved
    # chunks into a single prompt for the LLM.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
    return qa_chain
47
+
48
# Store the QA chain globally (across UI events); None until a
# document has been processed.
qa_chain = None

def upload_and_prepare(file):
    """Gradio handler: build the QA chain from the uploaded file.

    Args:
        file: Value from gr.File(type="filepath") — a path string in
            current Gradio, or a tempfile-like wrapper with a ``.name``
            attribute in older versions; None when nothing was uploaded.

    Returns:
        A status message string for the UI.
    """
    global qa_chain
    if file is None:
        return "❌ Please select a file to upload first."
    # NOTE(review): type="filepath" yields a str; older Gradio returned
    # an object exposing .name — accept both to avoid AttributeError.
    path = getattr(file, "name", file)
    qa_chain = process_file(path)
    return "✅ Document processed. You can now ask questions!"
56
+
57
def ask_question(query):
    """Gradio handler: answer *query* against the processed document.

    Returns an error string when no document has been processed yet or
    when the question is blank; otherwise the chain's answer text.
    """
    if not qa_chain:
        return "❌ Please upload a document first."
    # Don't send empty/whitespace-only questions into the chain.
    if not query or not query.strip():
        return "❌ Please enter a question."
    response = qa_chain.invoke({"query": query})
    return response["result"]
62
+
63
# Gradio UI
# Two-stage flow: (1) upload + process a .txt file, which populates the
# module-level qa_chain; (2) ask questions against that chain.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Ask Questions About Your Document (LangChain + Hugging Face)")

    with gr.Row():
        # type="filepath" hands the click handler a path to the upload.
        file_input = gr.File(label="πŸ“„ Upload .txt File", type="filepath")
        upload_btn = gr.Button("πŸ”„ Process Document")

    # Read-only status line updated by upload_and_prepare.
    upload_output = gr.Textbox(label="πŸ“ Status", interactive=False)

    with gr.Row():
        query_input = gr.Textbox(label="❓ Your Question")
        query_btn = gr.Button("🧠 Get Answer")

    answer_output = gr.Textbox(label="βœ… Answer", lines=4)

    # Wire the buttons to the module-level handlers defined above.
    upload_btn.click(upload_and_prepare, inputs=file_input, outputs=upload_output)
    query_btn.click(ask_question, inputs=query_input, outputs=answer_output)
81
+
82
# For local dev use: demo.launch()
# For HF Spaces
if __name__ == "__main__":
    # Default launch (localhost:7860); HF Spaces injects its own host/port.
    demo.launch()
86
+
87
+