# "Spaces: Sleeping" — HuggingFace Spaces status banner captured with the
# page scrape; kept as a comment so the file remains valid Python.
import gradio as gr
import PyPDF2
import io
from together import Together
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms.base import LLM
from typing import List, Optional
import traceback
# ---------------------------
# WRAP TOGETHER API AS LLM
# ---------------------------
class TogetherLLM(LLM):
    """LangChain-compatible wrapper around the Together chat-completions API.

    Implements the two members LangChain's ``LLM`` base class requires:
    the ``_llm_type`` property and the ``_call`` method.
    """

    client: Together = None
    model: str = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
    temperature: float = 0.3
    max_tokens: int = 1000

    def __init__(self, client, model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
                 temperature=0.3, max_tokens=1000, **kwargs):
        super().__init__(**kwargs)
        # object.__setattr__ bypasses pydantic validation so the
        # non-pydantic Together client can be stored on the instance.
        object.__setattr__(self, 'client', client)
        object.__setattr__(self, 'model', model)
        object.__setattr__(self, 'temperature', temperature)
        object.__setattr__(self, 'max_tokens', max_tokens)

    @property
    def _llm_type(self) -> str:
        # FIX: LangChain declares _llm_type as a @property on the LLM base
        # class; defining it as a plain method breaks serialization/identity
        # checks that read it as an attribute.
        return "together-llm"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send *prompt* as a single user message and return the reply text.

        API failures are returned as an error string rather than raised, so
        the surrounding chain keeps running.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error generating response: {str(e)}"

    class Config:
        # Allow the non-pydantic Together client as a field type.
        arbitrary_types_allowed = True
# ---------------------------
# PDF TEXT EXTRACTION
# ---------------------------
def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF, one Document per page with page metadata.

    Accepts any of:
      * a plain path string (newer Gradio ``gr.File`` passes the temp path),
      * a Gradio upload object exposing ``.name``,
      * a binary file-like object exposing ``.read``,
      * raw PDF bytes.

    Returns a list of langchain ``Document`` objects; pages that fail or
    contain no text get a placeholder Document so page numbering stays
    complete. On total failure returns a single error Document (page -1).
    """
    docs = []
    try:
        print("Starting PDF extraction...")
        # Normalize every accepted input type to raw bytes.
        if isinstance(pdf_file, str):
            # FIX: a bare path string has neither .name nor .read and
            # previously fell through to io.BytesIO(str) -> TypeError.
            with open(pdf_file, 'rb') as file:
                pdf_content = file.read()
        elif hasattr(pdf_file, 'name'):
            # File uploaded through Gradio (temp-file wrapper object)
            with open(pdf_file.name, 'rb') as file:
                pdf_content = file.read()
        elif hasattr(pdf_file, "read"):
            pdf_content = pdf_file.read()
            # Rewind so the caller can reuse the stream if it wants to.
            if hasattr(pdf_file, "seek"):
                pdf_file.seek(0)
        else:
            # Assume raw bytes were passed directly.
            pdf_content = pdf_file
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
        print(f"PDF has {len(pdf_reader.pages)} pages")
        for page_num, page in enumerate(pdf_reader.pages, start=1):
            try:
                page_text = page.extract_text()
                if page_text and page_text.strip():
                    docs.append(Document(
                        page_content=page_text.strip(),
                        metadata={"page": page_num, "source": "financial_policy"}
                    ))
                    print(f"Extracted text from page {page_num}: {len(page_text)} characters")
                else:
                    # Keep a placeholder so downstream code sees every page.
                    docs.append(Document(
                        page_content="[No extractable text found on this page]",
                        metadata={"page": page_num, "source": "financial_policy"}
                    ))
            except Exception as e:
                # A single bad page must not abort the whole extraction.
                print(f"Error extracting page {page_num}: {str(e)}")
                docs.append(Document(
                    page_content=f"[Error extracting page {page_num}: {str(e)}]",
                    metadata={"page": page_num, "source": "financial_policy"}
                ))
        print(f"Total documents extracted: {len(docs)}")
        return docs
    except Exception as e:
        print(f"Error in PDF extraction: {str(e)}")
        traceback.print_exc()
        return [Document(page_content=f"Error extracting text: {str(e)}", metadata={"page": -1})]
# ---------------------------
# BUILD KNOWLEDGE BASE (FAISS)
# ---------------------------
def build_vector_db(docs):
    """Convert extracted documents into FAISS vector DB"""
    try:
        print("Building vector database...")
        # Chunk pages into ~1000-char pieces with 100-char overlap,
        # preferring paragraph/sentence boundaries.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        chunks = splitter.split_documents(docs)
        print(f"Split into {len(chunks)} chunks")
        # CPU-only MiniLM sentence embeddings.
        embedder = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
        )
        print("Embeddings model loaded")
        # Build the in-memory FAISS index over all chunks.
        index = FAISS.from_documents(chunks, embedder)
        print("Vector database created successfully")
        return index
    except Exception as e:
        # Signal failure with None; caller reports it to the UI.
        print(f"Error building vector database: {str(e)}")
        traceback.print_exc()
        return None
# ---------------------------
# CHATBOT PIPELINE
# ---------------------------
def create_chatbot(api_key, db):
    """Set up ConversationalRetrievalChain with memory"""
    try:
        print("Creating chatbot...")
        together_client = Together(api_key=api_key)
        wrapped_llm = TogetherLLM(client=together_client)
        # Retrieve the 4 most similar chunks per question.
        doc_retriever = db.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 4},
        )
        # output_key="answer" tells the memory which chain output to store.
        chat_memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )
        chain = ConversationalRetrievalChain.from_llm(
            llm=wrapped_llm,
            retriever=doc_retriever,
            memory=chat_memory,
            return_source_documents=True,
            verbose=True,
        )
        print("Chatbot created successfully")
        return chain
    except Exception as e:
        # None signals failure; the UI handler reports it.
        print(f"Error creating chatbot: {str(e)}")
        traceback.print_exc()
        return None
# ---------------------------
# GRADIO APP
# ---------------------------
def create_app():
    """Build the Gradio Blocks UI: PDF upload/processing on the left,
    a retrieval-augmented chat panel on the right.

    NOTE(review): the source arrived with mojibake where emoji originally
    stood (Thai characters in place of emoji codepoints); plausible emoji
    have been restored below — confirm against the original deployment.
    """
    with gr.Blocks(title="📄 Financial Policy Document Chatbot", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 📄 Financial Policy Document Chatbot")
        gr.Markdown("""
        Upload a financial policy PDF document and ask questions about its content.
        The chatbot will provide answers with page references from the document.
        """)
        with gr.Row():
            with gr.Column(scale=1):
                api_key_input = gr.Textbox(
                    label="Together API Key",
                    placeholder="Enter your Together API key here...",
                    type="password",
                )
                pdf_file = gr.File(
                    label="Upload Financial Policy PDF",
                    file_types=[".pdf"],
                )
                process_button = gr.Button("🔄 Process PDF", variant="primary")
                status_message = gr.Textbox(label="Status", interactive=False, lines=3)
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Chat with Financial Policy Document", height=500)
                with gr.Row():
                    question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="Example: What is the budget allocation for infrastructure?",
                        lines=2,
                        scale=4,
                    )
                    submit_button = gr.Button("🔍 Ask", variant="secondary", scale=1)
                gr.Markdown("""
                **Sample Questions:**
                - What is the debt policy outlined in the document?
                - How much budget is allocated for infrastructure?
                - What are the revenue sources mentioned?
                - What are the key financial objectives?
                """)

        # Session state: the FAISS index and the QA chain built from it.
        db_state = gr.State()
        qa_chain_state = gr.State()

        def process_pdf_handler(pdf_file, api_key):
            """Stream status updates while the PDF is extracted, indexed,
            and wired into a ConversationalRetrievalChain.

            FIX: this function contains ``yield`` and is therefore a
            generator — the original early-validation ``return msg, None,
            None`` statements were swallowed (a generator's return value
            never reaches Gradio), so users saw no error. Every exit now
            yields its status first.
            """
            try:
                if pdf_file is None:
                    yield "⚠️ Please upload a PDF file.", None, None
                    return
                if not api_key or api_key.strip() == "":
                    yield "⚠️ Please enter your Together API key.", None, None
                    return
                yield "🔄 Processing PDF... This may take a few moments.", None, None
                # Extract text from PDF
                docs = extract_text_from_pdf(pdf_file)
                if not docs:
                    yield "⚠️ No text could be extracted from the PDF.", None, None
                    return
                # Keep only pages whose extraction actually produced text.
                valid_docs = [
                    doc for doc in docs
                    if not doc.page_content.startswith("[Error")
                    and not doc.page_content.startswith("[No extractable")
                ]
                if len(valid_docs) == 0:
                    yield "⚠️ No readable text found in the PDF.", None, None
                    return
                yield f"🔄 Extracted text from {len(docs)} pages. Building search database...", None, None
                db = build_vector_db(docs)
                if db is None:
                    yield "⚠️ Failed to build search database.", None, None
                    return
                yield "🔄 Search database created. Setting up chatbot...", None, None
                qa_chain = create_chatbot(api_key, db)
                if qa_chain is None:
                    yield "⚠️ Failed to create chatbot.", None, None
                    return
                yield (f"✅ Successfully processed PDF with {len(docs)} pages. "
                       f"Ready to answer questions!"), db, qa_chain
            except Exception as e:
                error_msg = f"❌ Error processing PDF: {str(e)}"
                print(f"Process PDF Error: {str(e)}")
                traceback.print_exc()
                yield error_msg, None, None

        def chat_handler(user_question, qa_chain, history):
            """Answer one question via the QA chain, appending page
            references from the retrieved chunks.

            Returns (updated_history, cleared_input). FIX: the original
            returned three values mapped to ``[chatbot, chatbot, question]``
            — the same component wired twice with *different* histories;
            the chatbot is now a single output.
            """
            if not user_question or user_question.strip() == "":
                return history, ""
            if qa_chain is None:
                return history + [(user_question, "⚠️ Please process a PDF document first.")], ""
            try:
                result = qa_chain({"question": user_question})
                answer = result["answer"]
                # Append page references from the retrieved source chunks.
                if result.get("source_documents"):
                    pages = [doc.metadata["page"]
                             for doc in result["source_documents"]
                             if "page" in doc.metadata]
                    if pages:
                        unique_pages = sorted(set(pages))
                        if len(unique_pages) == 1:
                            answer += f"\n\n📄 **Reference:** Page {unique_pages[0]}"
                        else:
                            answer += f"\n\n📄 **References:** Pages {', '.join(map(str, unique_pages))}"
                return history + [(user_question, answer)], ""
            except Exception as e:
                error_response = f"❌ Error processing question: {str(e)}"
                print(f"Chat Error: {str(e)}")
                traceback.print_exc()
                return history + [(user_question, error_response)], ""

        # Bind events. Both the button and Enter-in-textbox submit a question.
        process_button.click(
            fn=process_pdf_handler,
            inputs=[pdf_file, api_key_input],
            outputs=[status_message, db_state, qa_chain_state],
        )
        submit_button.click(
            fn=chat_handler,
            inputs=[question, qa_chain_state, chatbot],
            outputs=[chatbot, question],
        )
        question.submit(
            fn=chat_handler,
            inputs=[question, qa_chain_state, chatbot],
            outputs=[chatbot, question],
        )
    return app
# ---------------------------
# MAIN EXECUTION
# ---------------------------
if __name__ == "__main__":
    # Build the UI and serve it publicly on port 7860 with a share link.
    application = create_app()
    application.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )