import os
import traceback
import uuid
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredPowerPointLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_community.llms import HuggingFacePipeline
# Configure environment
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "google/flan-t5-large"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
THRESHOLD = 0.7 # Relevance threshold for retrieval
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TEMPERATURE = 0.1
MAX_NEW_TOKENS = 512
TOP_K = 3 # Number of chunks to retrieve
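
# Optional environment overrides for the model choices. A minimal sketch; the
# DOCAI_* variable names are assumptions, not part of any existing config.
EMBEDDING_MODEL = os.environ.get("DOCAI_EMBEDDING_MODEL", EMBEDDING_MODEL)
LLM_MODEL = os.environ.get("DOCAI_LLM_MODEL", LLM_MODEL)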
# Module-level session state
conversation_history = {}
current_session_id = None
current_document_store = None
current_document_name = None
FILE_EXTENSIONS = {
".pdf": PyPDFLoader,
".txt": TextLoader,
".docx": Docx2txtLoader,
".pptx": UnstructuredPowerPointLoader,
}
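
# Small helper (not part of the original flow) showing how the loader map
# doubles as a format whitelist; a sketch, assuming callers pass a path string.
def is_supported_file(file_path):
    """Return True if the file's extension has a registered loader."""
    return os.path.splitext(file_path)[1].lower() in FILE_EXTENSIONS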
class DocumentAIBot:
    """Wraps the embedding model, LLM pipeline, and retrieval logic for one document."""

    def __init__(self):
        self.setup_models()
def setup_models(self):
print("Setting up models...")
# Set up embedding model
self.embedding_model = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL,
model_kwargs={"device": DEVICE},
encode_kwargs={"normalize_embeddings": True}
)
# Set up LLM model
self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
self.llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL).to(DEVICE)
        # Create text generation pipeline. flan-t5 decodes greedily unless
        # sampling is enabled, so do_sample=True lets the low temperature
        # actually take effect (otherwise transformers ignores it and warns).
        self.text_generation_pipeline = pipeline(
            "text2text-generation",
            model=self.llm_model,
            tokenizer=self.tokenizer,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            do_sample=True,
            device=0 if DEVICE == "cuda" else -1
        )
# Create HuggingFace pipeline for LangChain
self.llm = HuggingFacePipeline(pipeline=self.text_generation_pipeline)
# Text splitter for document chunking
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
length_function=len
)
print("Models loaded successfully!")
def process_document(self, file_path):
"""Process a document and create a vector store."""
print(f"Processing document: {file_path}")
file_extension = os.path.splitext(file_path)[1].lower()
if file_extension not in FILE_EXTENSIONS:
raise ValueError(f"Unsupported file format: {file_extension}")
# Select appropriate loader
loader_class = FILE_EXTENSIONS[file_extension]
loader = loader_class(file_path)
# Load and split the document
documents = loader.load()
chunks = self.text_splitter.split_documents(documents)
if not chunks:
raise ValueError("No content extracted from the document")
print(f"Document split into {len(chunks)} chunks")
# Create vector store
vector_store = FAISS.from_documents(chunks, self.embedding_model)
return vector_store
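
    def save_vector_store(self, vector_store, index_dir):
        """Persist a FAISS index so a re-upload can skip re-embedding.

        A hedged sketch using FAISS.save_local from langchain_community;
        the on-disk layout under index_dir is whatever FAISS writes.
        """
        vector_store.save_local(index_dir)

    def load_vector_store(self, index_dir):
        """Reload an index written by save_vector_store (sketch; assumes a
        langchain_community version that supports the deserialization flag)."""
        return FAISS.load_local(
            index_dir,
            self.embedding_model,
            allow_dangerous_deserialization=True,
        )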
def setup_retrieval_chain(self, vector_store):
"""Set up the retrieval chain with the vector store."""
retriever = vector_store.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"k": TOP_K,
"score_threshold": THRESHOLD
}
)
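        # similarity_score_threshold rescales FAISS distances to a 0-1
        # relevance score and drops chunks scoring below THRESHOLD, so a
        # strict threshold can leave the chain with no retrieved context.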
chain = ConversationalRetrievalChain.from_llm(
llm=self.llm,
retriever=retriever,
return_source_documents=True,
verbose=True
)
return chain
    def get_answer(self, question, session_id, vector_store, chat_history):
        """Get an answer for a question using the retrieval chain.

        session_id is accepted for symmetry with the session bookkeeping but
        is not used inside this method.
        """
        if not question.strip():
            return "Please enter a question related to the document.", chat_history
        # Build a retrieval chain over the current vector store
        retrieval_chain = self.setup_retrieval_chain(vector_store)
        # The chain expects history as a list of (question, answer) tuples
        formatted_chat_history = list(chat_history)
        # invoke() is the current call style; calling the chain directly is deprecated
        response = retrieval_chain.invoke(
            {"question": question, "chat_history": formatted_chat_history}
        )
answer = response["answer"]
source_documents = response.get("source_documents", [])
# Format answer with source information
if source_documents:
source_info = "\n\nSources:"
seen_sources = set()
for doc in source_documents:
source = doc.metadata.get("source", "Unknown source")
page = doc.metadata.get("page", "Unknown page")
source_key = f"{source}-{page}"
if source_key not in seen_sources:
seen_sources.add(source_key)
if source == "Unknown source":
source_info += f"\n- Document chunk (page {page})"
else:
source_info += f"\n- {os.path.basename(source)} (page {page})"
answer += source_info
return answer, chat_history + [(question, answer)]
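
# Example of driving DocumentAIBot without the Gradio UI (a sketch; the
# sample.pdf path is hypothetical):
#
#   bot = DocumentAIBot()
#   store = bot.process_document("sample.pdf")
#   answer, history = bot.get_answer("What is this about?", "demo-session", store, [])
#   print(answer)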
def generate_session_id():
    """Generate a unique session ID."""
    return str(uuid.uuid4())
def process_uploaded_document(file_path):
"""Process an uploaded document and set up the session."""
global current_session_id, current_document_store, current_document_name, conversation_history
try:
if file_path is None:
return None, "Please upload a document first."
        # With type="filepath", Gradio passes the path of the already-saved
        # upload directly, so no temporary copy is needed.
# Extract filename for display
filename = os.path.basename(file_path)
# Create document AI bot if not already created
if not hasattr(process_uploaded_document, "bot"):
process_uploaded_document.bot = DocumentAIBot()
# Process the document
vector_store = process_uploaded_document.bot.process_document(file_path)
# Create a new session
session_id = generate_session_id()
conversation_history[session_id] = []
# Update global variables
current_session_id = session_id
current_document_store = vector_store
current_document_name = filename
return [], f"Document '{filename}' processed successfully. You can now ask questions about it."
    except Exception as e:
        traceback.print_exc()
        return None, f"Error processing document: {e}"
def clear_conversation():
    """Clear the conversation history for the current session."""
    global conversation_history, current_session_id
    if current_session_id and current_session_id in conversation_history:
        conversation_history[current_session_id] = []
    if current_document_name:
        return [], f"Conversation cleared. You can continue asking questions about '{current_document_name}'."
    return [], "Conversation cleared. Upload a document to start asking questions."
def answer_question(question, history):
    """Answer a question about the current document."""
    global current_session_id, current_document_store, conversation_history
    if not question or not question.strip():
        return "", history
    if not current_document_store:
        return "", history + [(question, "Please upload a document first.")]
    if not hasattr(process_uploaded_document, "bot"):
        return "", history + [(question, "Document AI bot not initialized. Please reload the page and try again.")]
try:
# Get current chat history
chat_history = conversation_history.get(current_session_id, [])
# Get answer
answer, updated_history = process_uploaded_document.bot.get_answer(
question,
current_session_id,
current_document_store,
chat_history
)
# Update conversation history
conversation_history[current_session_id] = updated_history
# Update the display history
history = history + [(question, answer)]
return "", history
    except Exception as e:
        traceback.print_exc()
        return "", history + [(question, f"Error generating answer: {e}")]
def build_interface():
"""Build and launch the Gradio interface."""
# Define the Gradio blocks
with gr.Blocks(title="Document AI Chatbot") as interface:
gr.Markdown("# 📄 Document AI Chatbot")
gr.Markdown("Upload a document (PDF, TXT, DOCX, PPTX) and ask questions about its content.")
with gr.Row():
with gr.Column(scale=1):
# Document upload and processing section
file_input = gr.File(
label="Upload Document",
file_types=[".pdf", ".txt", ".docx", ".pptx"],
type="filepath" # This returns the file path directly
)
upload_button = gr.Button("Process Document", variant="primary")
upload_status = gr.Textbox(label="Upload Status", interactive=False)
clear_button = gr.Button("Clear Conversation")
gr.Markdown("### System Information")
                system_info = gr.Markdown(
                    f"- Embedding Model: {EMBEDDING_MODEL}\n"
                    f"- Language Model: {LLM_MODEL}\n"
                    f"- Running on: {DEVICE}\n"
                    f"- Chunk Size: {CHUNK_SIZE}\n"
                    f"- Relevance Threshold: {THRESHOLD}"
                )
with gr.Column(scale=2):
# Chat interface
chatbot = gr.Chatbot(
label="Conversation",
height=500,
show_label=True,
)
with gr.Row():
question_input = gr.Textbox(
label="Ask a question about the document",
placeholder="What is the main topic of this document?",
lines=2,
show_label=True
)
submit_button = gr.Button("Submit", variant="primary")
# Set up event handlers
upload_button.click(
process_uploaded_document,
inputs=[file_input],
outputs=[chatbot, upload_status]
)
submit_button.click(
answer_question,
inputs=[question_input, chatbot],
outputs=[question_input, chatbot]
)
question_input.submit(
answer_question,
inputs=[question_input, chatbot],
outputs=[question_input, chatbot]
)
clear_button.click(
clear_conversation,
inputs=[],
outputs=[chatbot, upload_status]
)
return interface
# Main execution
if __name__ == "__main__":
demo = build_interface()
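    # On shared hardware, Gradio's request queue keeps concurrent users from
    # overlapping generate() calls; uncomment to enable (max_size here is an
    # assumption, not a tuned value):
    # demo.queue(max_size=16)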
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)