Update app.py
Browse files
app.py
CHANGED
|
@@ -31,7 +31,7 @@ if "file_names" not in st.session_state:
|
|
| 31 |
class PDFQAAssistant:
|
| 32 |
def __init__(self,
|
| 33 |
hf_token: str = None,
|
| 34 |
-
model_name: str = "
|
| 35 |
embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
|
| 36 |
"""
|
| 37 |
Initialize the PDF Q&A Assistant with Hugging Face models.
|
|
@@ -52,7 +52,7 @@ class PDFQAAssistant:
|
|
| 52 |
self.llm = HuggingFaceEndpoint(
|
| 53 |
repo_id=model_name,
|
| 54 |
huggingfacehub_api_token=hf_token,
|
| 55 |
-
max_length=
|
| 56 |
temperature=0.5
|
| 57 |
)
|
| 58 |
|
|
@@ -64,8 +64,8 @@ class PDFQAAssistant:
|
|
| 64 |
|
| 65 |
# Initialize text splitter for chunking documents
|
| 66 |
self.text_splitter = RecursiveCharacterTextSplitter(
|
| 67 |
-
chunk_size=
|
| 68 |
-
chunk_overlap=
|
| 69 |
length_function=len
|
| 70 |
)
|
| 71 |
|
|
@@ -274,20 +274,25 @@ def main():
|
|
| 274 |
if use_manual_token:
|
| 275 |
hf_token = st.text_input("Enter Hugging Face API Token:", type="password")
|
| 276 |
|
| 277 |
-
# Model selection
|
| 278 |
st.subheader("Model Settings")
|
| 279 |
model_name = st.selectbox(
|
| 280 |
"Select LLM model:",
|
| 281 |
-
[
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
| 284 |
index=0
|
| 285 |
)
|
| 286 |
|
| 287 |
embedding_model = st.selectbox(
|
| 288 |
"Select Embedding model:",
|
| 289 |
-
[
|
| 290 |
-
|
|
|
|
|
|
|
| 291 |
index=0
|
| 292 |
)
|
| 293 |
|
|
@@ -314,29 +319,33 @@ def main():
|
|
| 314 |
# Process each uploaded file
|
| 315 |
for pdf_file in uploaded_files:
|
| 316 |
file_name = pdf_file.name
|
| 317 |
-
st.session_state.file_names
|
|
|
|
| 318 |
assistant.process_pdf(pdf_file, file_name)
|
| 319 |
|
| 320 |
# Store the assistant in session state
|
| 321 |
st.session_state.assistant = assistant
|
| 322 |
except Exception as e:
|
| 323 |
st.error(f"Error initializing assistant: {e}")
|
|
|
|
| 324 |
|
| 325 |
# Document management
|
| 326 |
-
if st.session_state.document_processed:
|
| 327 |
st.subheader("Document Management")
|
| 328 |
|
| 329 |
if st.button("Clear Chat History"):
|
| 330 |
-
st.session_state
|
|
|
|
| 331 |
st.session_state.chat_history = []
|
| 332 |
st.success("Chat history cleared!")
|
| 333 |
|
| 334 |
if st.button("Generate Document Summary"):
|
| 335 |
-
|
| 336 |
-
|
|
|
|
| 337 |
|
| 338 |
# Main area for chat interface
|
| 339 |
-
if not st.session_state.document_processed:
|
| 340 |
st.info("👈 Please upload and process a PDF document to get started.")
|
| 341 |
|
| 342 |
# Display demo information
|
|
@@ -405,6 +414,7 @@ def main():
|
|
| 405 |
})
|
| 406 |
except Exception as e:
|
| 407 |
st.error(f"Error getting response: {e}")
|
|
|
|
| 408 |
|
| 409 |
if __name__ == "__main__":
|
| 410 |
main()
|
|
|
|
| 31 |
class PDFQAAssistant:
|
| 32 |
def __init__(self,
|
| 33 |
hf_token: str = None,
|
| 34 |
+
model_name: str = "google/flan-t5-base", # Changed to a more accessible model
|
| 35 |
embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
|
| 36 |
"""
|
| 37 |
Initialize the PDF Q&A Assistant with Hugging Face models.
|
|
|
|
| 52 |
self.llm = HuggingFaceEndpoint(
|
| 53 |
repo_id=model_name,
|
| 54 |
huggingfacehub_api_token=hf_token,
|
| 55 |
+
max_length=512, # Reduced for smaller models
|
| 56 |
temperature=0.5
|
| 57 |
)
|
| 58 |
|
|
|
|
| 64 |
|
| 65 |
# Initialize text splitter for chunking documents
|
| 66 |
self.text_splitter = RecursiveCharacterTextSplitter(
|
| 67 |
+
chunk_size=800, # Smaller chunks for better processing
|
| 68 |
+
chunk_overlap=150,
|
| 69 |
length_function=len
|
| 70 |
)
|
| 71 |
|
|
|
|
| 274 |
if use_manual_token:
|
| 275 |
hf_token = st.text_input("Enter Hugging Face API Token:", type="password")
|
| 276 |
|
| 277 |
+
# Model selection with open-source models
|
| 278 |
st.subheader("Model Settings")
|
| 279 |
model_name = st.selectbox(
|
| 280 |
"Select LLM model:",
|
| 281 |
+
[
|
| 282 |
+
"google/flan-t5-base", # Smaller, more accessible model
|
| 283 |
+
"google/flan-t5-small", # Even smaller model
|
| 284 |
+
"facebook/bart-large-cnn", # Good for summarization
|
| 285 |
+
"distilbert-base-uncased" # Lightweight model
|
| 286 |
+
],
|
| 287 |
index=0
|
| 288 |
)
|
| 289 |
|
| 290 |
embedding_model = st.selectbox(
|
| 291 |
"Select Embedding model:",
|
| 292 |
+
[
|
| 293 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
| 294 |
+
"sentence-transformers/paraphrase-MiniLM-L3-v2" # Smaller embedding model
|
| 295 |
+
],
|
| 296 |
index=0
|
| 297 |
)
|
| 298 |
|
|
|
|
| 319 |
# Process each uploaded file
|
| 320 |
for pdf_file in uploaded_files:
|
| 321 |
file_name = pdf_file.name
|
| 322 |
+
if file_name not in st.session_state.file_names:
|
| 323 |
+
st.session_state.file_names.append(file_name)
|
| 324 |
assistant.process_pdf(pdf_file, file_name)
|
| 325 |
|
| 326 |
# Store the assistant in session state
|
| 327 |
st.session_state.assistant = assistant
|
| 328 |
except Exception as e:
|
| 329 |
st.error(f"Error initializing assistant: {e}")
|
| 330 |
+
st.error("Try selecting a different model or check your token permissions.")
|
| 331 |
|
| 332 |
# Document management
|
| 333 |
+
if st.session_state.get("document_processed", False):
|
| 334 |
st.subheader("Document Management")
|
| 335 |
|
| 336 |
if st.button("Clear Chat History"):
|
| 337 |
+
if "assistant" in st.session_state:
|
| 338 |
+
st.session_state.assistant.clear_memory()
|
| 339 |
st.session_state.chat_history = []
|
| 340 |
st.success("Chat history cleared!")
|
| 341 |
|
| 342 |
if st.button("Generate Document Summary"):
|
| 343 |
+
if "assistant" in st.session_state and len(st.session_state.file_names) > 0:
|
| 344 |
+
get_document_summary(st.session_state.assistant,
|
| 345 |
+
st.session_state.file_names[0])
|
| 346 |
|
| 347 |
# Main area for chat interface
|
| 348 |
+
if not st.session_state.get("document_processed", False):
|
| 349 |
st.info("👈 Please upload and process a PDF document to get started.")
|
| 350 |
|
| 351 |
# Display demo information
|
|
|
|
| 414 |
})
|
| 415 |
except Exception as e:
|
| 416 |
st.error(f"Error getting response: {e}")
|
| 417 |
+
st.error("Please try a different question or model.")
|
| 418 |
|
| 419 |
if __name__ == "__main__":
|
| 420 |
main()
|