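"""Document AI Chatbot: a Gradio app for question answering over uploaded documents.

An uploaded PDF/TXT/DOCX/PPTX file is loaded, split into overlapping
chunks, embedded with a sentence-transformers model, and indexed in
FAISS; questions are then answered by a flan-t5 model over the
retrieved chunks via LangChain's ConversationalRetrievalChain.
"""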

import os
import uuid

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    Docx2txtLoader,
    UnstructuredPowerPointLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_community.llms import HuggingFacePipeline

# Configure environment
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "google/flan-t5-large"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
THRESHOLD = 0.7  # Relevance threshold for retrieval
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TEMPERATURE = 0.1
MAX_NEW_TOKENS = 512
TOP_K = 3  # Number of chunks to retrieve

# Store for conversation history and the currently loaded document
conversation_history = {}
current_session_id = None
current_document_store = None
current_document_name = None

FILE_EXTENSIONS = {
    ".pdf": PyPDFLoader,
    ".txt": TextLoader,
    ".docx": Docx2txtLoader,
    ".pptx": UnstructuredPowerPointLoader,
}
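

# DocumentAIBot bundles the whole RAG pipeline: model setup, document
# chunking and indexing, and retrieval-augmented question answering.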
class DocumentAIBot:
    def __init__(self):
        self.setup_models()

    def setup_models(self):
        print("Setting up models...")
        # Set up embedding model (normalized embeddings so similarity
        # scores behave like cosine similarity)
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={"device": DEVICE},
            encode_kwargs={"normalize_embeddings": True},
        )
        # Set up LLM model
        self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
        self.llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL).to(DEVICE)
        # Create text generation pipeline; sampling must be enabled,
        # otherwise the temperature setting is ignored
        self.text_generation_pipeline = pipeline(
            "text2text-generation",
            model=self.llm_model,
            tokenizer=self.tokenizer,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            device=0 if DEVICE == "cuda" else -1,
        )
        # Wrap the pipeline so LangChain can use it as an LLM
        self.llm = HuggingFacePipeline(pipeline=self.text_generation_pipeline)
        # Text splitter for document chunking
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len,
        )
        print("Models loaded successfully!")

    def process_document(self, file_path):
        """Process a document and create a vector store."""
        print(f"Processing document: {file_path}")
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension not in FILE_EXTENSIONS:
            raise ValueError(f"Unsupported file format: {file_extension}")
        # Select appropriate loader
        loader_class = FILE_EXTENSIONS[file_extension]
        loader = loader_class(file_path)
        # Load and split the document
        documents = loader.load()
        chunks = self.text_splitter.split_documents(documents)
        if not chunks:
            raise ValueError("No content extracted from the document")
        print(f"Document split into {len(chunks)} chunks")
        # Create vector store
        vector_store = FAISS.from_documents(chunks, self.embedding_model)
        return vector_store

    def setup_retrieval_chain(self, vector_store):
        """Set up the retrieval chain with the vector store."""
        # Return at most TOP_K chunks, and only those whose relevance
        # score clears THRESHOLD
        retriever = vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": TOP_K,
                "score_threshold": THRESHOLD,
            },
        )
        # The chain condenses the question and chat history into a
        # standalone query before retrieving
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=retriever,
            return_source_documents=True,
            verbose=True,
        )
        return chain

    def get_answer(self, question, session_id, vector_store, chat_history):
        """Get an answer for a question using the retrieval chain."""
        if not question.strip():
            return "Please enter a question related to the document.", chat_history
        # Build a fresh retrieval chain per query; the vector store holds
        # all the state, so this is cheap
        retrieval_chain = self.setup_retrieval_chain(vector_store)
        # Chat history is passed as (question, answer) tuples
        formatted_chat_history = [(q, a) for q, a in chat_history]
        # Get response from the chain
        response = retrieval_chain(
            {"question": question, "chat_history": formatted_chat_history}
        )
        answer = response["answer"]
        source_documents = response.get("source_documents", [])
        # Append deduplicated source information to the answer
        if source_documents:
            source_info = "\n\nSources:"
            seen_sources = set()
            for doc in source_documents:
                source = doc.metadata.get("source", "Unknown source")
                page = doc.metadata.get("page", "Unknown page")
                source_key = f"{source}-{page}"
                if source_key not in seen_sources:
                    seen_sources.add(source_key)
                    if source == "Unknown source":
                        source_info += f"\n- Document chunk (page {page})"
                    else:
                        source_info += f"\n- {os.path.basename(source)} (page {page})"
            answer += source_info
        return answer, chat_history + [(question, answer)]

def generate_session_id():
    """Generate a unique session ID."""
    return str(uuid.uuid4())

def process_uploaded_document(file_path):
    """Process an uploaded document and set up a new session for it."""
    global current_session_id, current_document_store, current_document_name
    try:
        if file_path is None:
            return None, "Please upload a document first."
        # With type="filepath", newer Gradio versions pass the path of the
        # already-saved upload directly, so nothing needs to be written to disk
        filename = os.path.basename(file_path)
        # Create the document AI bot lazily on first upload and reuse it
        if not hasattr(process_uploaded_document, "bot"):
            process_uploaded_document.bot = DocumentAIBot()
        # Process the document
        vector_store = process_uploaded_document.bot.process_document(file_path)
        # Create a new session
        session_id = generate_session_id()
        conversation_history[session_id] = []
        # Update global state
        current_session_id = session_id
        current_document_store = vector_store
        current_document_name = filename
        return [], f"Document '{filename}' processed successfully. You can now ask questions about it."
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error processing document: {str(e)}"

def clear_conversation():
    """Clear the conversation history for the current session."""
    if current_session_id and current_session_id in conversation_history:
        conversation_history[current_session_id] = []
    if current_document_name is None:
        return [], "Conversation cleared. Upload a document to get started."
    return [], f"Conversation cleared. You can continue asking questions about '{current_document_name}'."

def answer_question(question, history):
    """Answer a question about the current document."""
    if not current_document_store:
        return "", history + [(question, "Please upload a document first.")]
    if not hasattr(process_uploaded_document, "bot"):
        return "", history + [(question, "Document AI bot not initialized. Please reload the page and try again.")]
    try:
        # Get current chat history
        chat_history = conversation_history.get(current_session_id, [])
        # Get answer
        answer, updated_history = process_uploaded_document.bot.get_answer(
            question,
            current_session_id,
            current_document_store,
            chat_history,
        )
        # Update conversation history
        conversation_history[current_session_id] = updated_history
        # Update the display history and clear the input box
        history = history + [(question, answer)]
        return "", history
    except Exception as e:
        import traceback
        traceback.print_exc()
        return "", history + [(question, f"Error generating answer: {str(e)}")]

def build_interface():
    """Build the Gradio interface."""
    with gr.Blocks(title="Document AI Chatbot") as interface:
        gr.Markdown("# 📄 Document AI Chatbot")
        gr.Markdown("Upload a document (PDF, TXT, DOCX, PPTX) and ask questions about its content.")
        with gr.Row():
            with gr.Column(scale=1):
                # Document upload and processing section
                file_input = gr.File(
                    label="Upload Document",
                    file_types=[".pdf", ".txt", ".docx", ".pptx"],
                    type="filepath",  # Returns the saved file's path directly
                )
                upload_button = gr.Button("Process Document", variant="primary")
                upload_status = gr.Textbox(label="Upload Status", interactive=False)
                clear_button = gr.Button("Clear Conversation")
                gr.Markdown("### System Information")
                gr.Markdown(f"""
                - Embedding Model: {EMBEDDING_MODEL}
                - Language Model: {LLM_MODEL}
                - Running on: {DEVICE}
                - Chunk Size: {CHUNK_SIZE}
                - Relevance Threshold: {THRESHOLD}
                """)
            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_label=True,
                )
                with gr.Row():
                    question_input = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="What is the main topic of this document?",
                        lines=2,
                        show_label=True,
                    )
                    submit_button = gr.Button("Submit", variant="primary")
        # Set up event handlers
        upload_button.click(
            process_uploaded_document,
            inputs=[file_input],
            outputs=[chatbot, upload_status],
        )
        submit_button.click(
            answer_question,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
        )
        question_input.submit(
            answer_question,
            inputs=[question_input, chatbot],
            outputs=[question_input, chatbot],
        )
        clear_button.click(
            clear_conversation,
            inputs=[],
            outputs=[chatbot, upload_status],
        )
    return interface

# Main execution
if __name__ == "__main__":
    demo = build_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
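
# Note: the 0.0.0.0 bind and port 7860 match what Hugging Face Spaces
# expects of a Gradio app; to try it locally, run this script with
# Python and open http://localhost:7860 in a browser.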