Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import time | |
| from datetime import datetime | |
| from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from pptx import Presentation | |
| from io import BytesIO | |
| import shutil | |
| import logging | |
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face token. If the variable is absent from the environment, a
# placeholder value is stored there and a warning is logged.
# NOTE(review): the placeholder is a non-empty string, so any code that later
# reads this variable will see "default-token" rather than "no token" —
# consider leaving the variable unset instead.
_HF_TOKEN_ENV_VAR = "HUGGINGFACEHUB_API_TOKEN"
_HF_TOKEN_PLACEHOLDER = "default-token"

os.environ.setdefault(_HF_TOKEN_ENV_VAR, _HF_TOKEN_PLACEHOLDER)
if os.environ[_HF_TOKEN_ENV_VAR] == _HF_TOKEN_PLACEHOLDER:
    logger.warning("HUGGINGFACEHUB_API_TOKEN not set. Some models may not work.")
# Model and embedding options.
# Keys are the labels shown in the UI dropdowns; values are Hugging Face Hub
# repository ids passed to HuggingFaceEndpoint / HuggingFaceEmbeddings.
LLM_MODELS = {
    "Lightweight (Gemma-2B)": "google/gemma-2b-it",
    "Balanced (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    # Meta publishes Llama 3 as "meta-llama/Meta-Llama-3-*". The previous id
    # ("meta-llama/Llama-3-8b-hf") used the Llama 2 "-hf" naming scheme and is
    # not a published repository. The instruct variant matches the other
    # chat-tuned choices. The repo is gated: the HF token needs granted access.
    "High Accuracy (Llama-3-8B)": "meta-llama/Meta-Llama-3-8B-Instruct",
}
EMBEDDING_MODELS = {
    "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2",
    "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2",
    "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5",
}
# Global application state. These are module-level names, so every request
# handled by this process reads and writes the same objects.
PERSIST_DIRECTORY = "./chroma_db"  # on-disk location of the Chroma index

vector_store = None  # Chroma store, assigned by process_documents
qa_chain = None  # retrieval chain, assigned by initialize_qa_chain
chat_history = []  # alternating ("User", question) / ("Bot", answer) tuples
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Custom PPTX loader
class PPTXLoader:
    """Extract the text of each slide in a .pptx file.

    Unlike the LangChain loaders, ``load`` returns plain dicts of the form
    ``{"page_content": str, "metadata": {"source": path, "slide": n}}``,
    not ``Document`` objects.
    """

    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        """Return one dict per slide that contains non-whitespace text.

        Slide numbers in the metadata start at 1. Any error while reading the
        file is logged and an empty list is returned (slides collected before
        the error are discarded).
        """
        try:
            slides_found = []
            with open(self.file_path, "rb") as handle:
                presentation = Presentation(BytesIO(handle.read()))
                for index, slide in enumerate(presentation.slides, start=1):
                    pieces = []
                    for shape in slide.shapes:
                        # Only shapes exposing a non-empty .text contribute.
                        snippet = getattr(shape, "text", "")
                        if snippet:
                            pieces.append(snippet + "\n")
                    slide_text = "".join(pieces)
                    if not slide_text.strip():
                        continue
                    slides_found.append(
                        {
                            "page_content": slide_text,
                            "metadata": {"source": self.file_path, "slide": index},
                        }
                    )
        except Exception as exc:
            logger.error(f"Error loading PPTX {self.file_path}: {str(exc)}")
            return []
        return slides_found
# Function to load documents
def _to_document(doc):
    """Return *doc* as a LangChain ``Document``.

    Results that are not plain dicts are returned unchanged. A
    ``{"page_content": ..., "metadata": ...}`` dict (the shape produced by the
    custom PPTX loader) is converted, because LangChain text splitters and
    vector stores read ``.page_content`` / ``.metadata`` attributes.
    """
    if isinstance(doc, dict):
        from langchain.schema import Document

        return Document(
            page_content=doc["page_content"],
            metadata=dict(doc.get("metadata") or {}),
        )
    return doc


def load_documents(files):
    """Load every supported upload into a flat list of LangChain documents.

    Supported extensions (matched case-insensitively): .pdf, .txt, .docx, .pptx.
    Unsupported files are skipped with a warning. A file that fails to load is
    logged and skipped, so one bad upload does not abort the whole batch.

    Args:
        files: iterable of uploads. Each item may be a path (``str`` /
            ``os.PathLike``) or an object exposing the path as ``.name``.
            ``None`` or an empty iterable yields ``[]``.

    Returns:
        list of ``Document`` objects. Dict results from the PPTX loader are
        converted so that downstream splitting works for every format.
    """
    if not files:
        return []

    # Lower-cased extension -> loader class constructed with the file path.
    loaders = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": Docx2txtLoader,
        ".pptx": PPTXLoader,
    }
    documents = []
    for upload in files:
        # Pre-bind so the error log below always has something to show, even
        # if resolving the path itself raises.
        file_path = upload
        try:
            # Uploads may arrive as path strings (including str subclasses) or
            # as file objects carrying the path in .name.
            if isinstance(upload, (str, os.PathLike)):
                file_path = os.fspath(upload)
            else:
                file_path = upload.name
            logger.info(f"Loading file: {file_path}")
            extension = os.path.splitext(file_path)[1].lower()
            loader_cls = loaders.get(extension)
            if loader_cls is None:
                logger.warning(f"Skipping unsupported file type: {file_path}")
                continue
            documents.extend([_to_document(doc) for doc in loader_cls(file_path).load()])
        except Exception as e:
            logger.error(f"Error loading file {file_path}: {str(e)}")
    return documents
# Function to process documents and create vector store
def process_documents(files, chunk_size, chunk_overlap, embedding_model):
    """Build a fresh Chroma vector store from the uploaded files.

    Steps:
    1. Wipe ``PERSIST_DIRECTORY``.
    2. Load the files.
    3. Split them with a ``RecursiveCharacterTextSplitter``.
    4. Embed the chunks with the selected Hugging Face model.
    5. Persist the new store to ``PERSIST_DIRECTORY``.

    Args:
        files: uploads from the file component, forwarded to ``load_documents``.
        chunk_size: maximum chunk length in characters; cast with ``int()``.
        chunk_overlap: characters shared by consecutive chunks; cast with ``int()``.
        embedding_model: display name, looked up in ``EMBEDDING_MODELS``.

    Returns:
        ``(status_message, None)``. The second value is wired to the chat display.

    Side effects:
        Replaces the global ``vector_store``. It also discards any existing
        ``qa_chain``, whose retriever points at the store being deleted, so the
        QA chain must be initialized again.
    """
    global vector_store, qa_chain
    if not files:
        return "Please upload at least one document.", None

    # The old store is deleted below, and a previously built chain retrieves
    # from it. Drop both references before anything can fail, so questions
    # can never be answered from a deleted or outdated index.
    had_chain = qa_chain is not None
    vector_store = None
    qa_chain = None

    # Clear existing vector store to avoid dimensionality mismatch when the
    # embedding model changes.
    if os.path.exists(PERSIST_DIRECTORY):
        try:
            shutil.rmtree(PERSIST_DIRECTORY)
            logger.info("Cleared existing ChromaDB directory.")
        except Exception as e:
            logger.error(f"Error clearing ChromaDB directory: {str(e)}")
            return f"Error clearing vector store: {str(e)}", None

    documents = load_documents(files)
    if not documents:
        return "No valid documents loaded. Check file formats or content.", None

    try:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=int(chunk_size),
            chunk_overlap=int(chunk_overlap),
            length_function=len,
        )
        doc_splits = text_splitter.split_documents(documents)
        logger.info(f"Split {len(documents)} documents into {len(doc_splits)} chunks.")
    except Exception as e:
        logger.error(f"Error splitting documents: {str(e)}")
        return f"Error splitting documents: {str(e)}", None
    if not doc_splits:
        # Nothing to embed.
        return "No text could be extracted from the uploaded documents.", None

    try:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODELS[embedding_model])
    except Exception as e:
        logger.error(f"Error initializing embeddings for {embedding_model}: {str(e)}")
        return f"Error initializing embeddings: {str(e)}", None

    try:
        vector_store = Chroma.from_documents(doc_splits, embeddings, persist_directory=PERSIST_DIRECTORY)
    except Exception as e:
        logger.error(f"Error creating vector store: {str(e)}")
        return f"Error creating vector store: {str(e)}", None

    message = f"Processed {len(documents)} documents into {len(doc_splits)} chunks."
    if had_chain:
        message += " Re-initialize the QA chain to use them."
    return message, None
# Function to initialize QA chain
def initialize_qa_chain(llm_model, temperature):
    """Create the conversational retrieval chain over the current vector store.

    The chain uses a Hugging Face inference endpoint as the LLM, a retriever
    returning the top 3 chunks, and the shared module-level ``memory``.

    Args:
        llm_model: display name, looked up in ``LLM_MODELS`` for the repo id.
        temperature: sampling temperature; cast with ``float()``.

    Returns:
        ``(status_message, None)``. On success the global ``qa_chain`` is
        replaced. On failure it is left unchanged and the error is logged.
    """
    global qa_chain
    if not vector_store:
        return "Please process documents first.", None
    try:
        endpoint_settings = {
            "repo_id": LLM_MODELS[llm_model],
            "task": "text-generation",
            "temperature": float(temperature),
            "max_new_tokens": 512,
            "huggingfacehub_api_token": os.environ["HUGGINGFACEHUB_API_TOKEN"],
        }
        endpoint = HuggingFaceEndpoint(**endpoint_settings)
        top_k_retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=endpoint,
            retriever=top_k_retriever,
            memory=memory,
        )
    except Exception as err:
        logger.error(f"Error initializing QA chain for {llm_model}: {str(err)}")
        return (
            f"Error initializing QA chain: {str(err)}. "
            f"Ensure your HF token has access to {llm_model}."
        ), None
    else:
        logger.info(f"Initialized QA chain with {llm_model}.")
        return "QA chain initialized successfully.", None
# Function to handle user query
def _history_as_chat_pairs(history):
    """Convert role-tagged history into ``(user_message, bot_message)`` pairs.

    *history* holds alternating ``("User", question)`` and ``("Bot", answer)``
    tuples, as appended by ``answer_question``. The chat display is created
    without an explicit ``type``, and this assumes it uses Gradio's tuple
    format: one ``(user, bot)`` pair per exchange. Returning the role-tagged
    list directly would render the literal role names as messages.
    """
    return [
        (user_msg, bot_msg)
        for (_, user_msg), (_, bot_msg) in zip(history[0::2], history[1::2])
    ]


def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
    """Answer *question* with the initialized QA chain and record the exchange.

    Only *question* affects the result. The other parameters are accepted
    because the UI event passes the current settings, but they are unused.

    Returns:
        ``(answer_or_status_message, chat_pairs)``. ``chat_pairs`` is the
        recorded history converted to ``(user, bot)`` pairs for display. The
        global ``chat_history`` keeps its ``(role, message)`` format, which the
        export reads.
    """
    if not vector_store:
        return "Please process documents first.", _history_as_chat_pairs(chat_history)
    if not qa_chain:
        return "Please initialize the QA chain.", _history_as_chat_pairs(chat_history)
    if not question or not question.strip():
        return "Please enter a valid question.", _history_as_chat_pairs(chat_history)
    try:
        # invoke() is the Runnable entry point; calling the chain object
        # directly is deprecated in LangChain.
        response = qa_chain.invoke({"question": question})["answer"]
    except Exception as e:
        logger.error(f"Error answering question: {str(e)}")
        return f"Error answering question: {str(e)}", _history_as_chat_pairs(chat_history)
    # Mutated in place; reset_app rebinds the global, which is read fresh on each call.
    chat_history.append(("User", question))
    chat_history.append(("Bot", response))
    logger.info(f"Answered question: {question}")
    return response, _history_as_chat_pairs(chat_history)
# Function to export chat history
def export_chat():
    """Write the chat history to a timestamped text file in the working directory.

    Each ``(role, message)`` entry is written as ``"role: message"`` followed
    by a blank line.

    Returns:
        ``(status_message, filename)`` on success, or ``(message, None)`` when
        there is no history or the write fails (the error is logged).
    """
    if not chat_history:
        return "No chat history to export.", None
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.txt"
        # Explicit UTF-8: questions/answers may contain non-ASCII text that the
        # platform default encoding (e.g. cp1252 on Windows) cannot encode.
        with open(filename, "w", encoding="utf-8") as f:
            f.writelines(f"{role}: {message}\n\n" for role, message in chat_history)
        logger.info(f"Exported chat history to {filename}.")
        return f"Chat history exported to {filename}.", filename
    except Exception as e:
        logger.error(f"Error exporting chat history: {str(e)}")
        return f"Error exporting chat history: {str(e)}", None
# Function to reset the app
def reset_app():
    """Discard all in-memory state and delete the persisted Chroma directory.

    The vector store, QA chain and chat history are cleared, and the
    conversation memory is replaced with a fresh ``ConversationBufferMemory``.

    Returns:
        ``(status_message, None)``. The second value is sent to the chat display.
        Errors are logged and reported in the status message.
    """
    global vector_store, qa_chain, chat_history, memory
    try:
        vector_store, qa_chain, chat_history = None, None, []
        memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        store_on_disk = os.path.exists(PERSIST_DIRECTORY)
        if store_on_disk:
            shutil.rmtree(PERSIST_DIRECTORY)
            logger.info("Cleared ChromaDB directory on reset.")
    except Exception as err:
        logger.error(f"Error resetting app: {str(err)}")
        return f"Error resetting app: {str(err)}", None
    logger.info("App reset successfully.")
    return "App reset successfully.", None
# Gradio interface
# File extensions offered by the upload widget.
UPLOAD_FILE_TYPES = [".pdf", ".txt", ".docx", ".pptx"]

with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
    gr.Markdown("# DocTalk: Document Q&A Chatbot")
    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), select models, tune parameters, and ask questions!")

    with gr.Row():
        # Left column: uploads, processing/reset buttons and the status line.
        with gr.Column(scale=2):
            file_upload = gr.Files(label="Upload Documents", file_types=UPLOAD_FILE_TYPES)
            with gr.Row():
                process_button = gr.Button("Process Documents")
                reset_button = gr.Button("Reset App")
            status = gr.Textbox(label="Status", interactive=False)

        # Right column: model selection and tuning parameters.
        with gr.Column(scale=1):
            llm_model = gr.Dropdown(
                choices=list(LLM_MODELS),
                label="Select LLM Model",
                value="Lightweight (Gemma-2B)",
            )
            embedding_model = gr.Dropdown(
                choices=list(EMBEDDING_MODELS),
                label="Select Embedding Model",
                value="Lightweight (MiniLM-L6)",
            )
            temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
            chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
            init_button = gr.Button("Initialize QA Chain")

    gr.Markdown("## Chat Interface")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
    answer = gr.Textbox(label="Answer", interactive=False)
    chat_display = gr.Chatbot(label="Chat History")
    export_button = gr.Button("Export Chat History")
    export_file = gr.File(label="Exported Chat File")

    # Event handlers. Input components are passed to each handler positionally,
    # so each inputs list must follow that handler's parameter order.
    process_button.click(
        process_documents,
        inputs=[file_upload, chunk_size, chunk_overlap, embedding_model],
        outputs=[status, chat_display],
    )
    init_button.click(
        initialize_qa_chain,
        inputs=[llm_model, temperature],
        outputs=[status, chat_display],
    )
    question.submit(
        answer_question,
        inputs=[question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap],
        outputs=[answer, chat_display],
    )
    export_button.click(export_chat, outputs=[status, export_file])
    reset_button.click(reset_app, outputs=[status, chat_display])

demo.launch()