# AskMoli - Chatbot for PDFs (Hugging Face Space app)
import logging
import os
import sqlite3
import time

import pandas as pd
import requests  # for HTTP calls to Gemini

import gradio as gr
from langchain.chains import RetrievalQA  # for QA chain
from langchain.document_loaders import OnlinePDFLoader  # for loading PDF text
from langchain.embeddings import HuggingFaceEmbeddings  # open source embedding model
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma  # vectorization from langchain_community
from langchain_core.prompts import PromptTemplate  # prompt template import
# ------------------------------
# Gemini API Wrapper
# ------------------------------
class ChatGemini:
    """Minimal callable LLM wrapper around the Google Gemini REST API.

    Implements the plain-callable interface that the RetrievalQA chain below
    invokes: calling the instance with a prompt string returns the generated
    text (or raises on an HTTP error).
    """

    def __init__(self, api_key, temperature=0, model_name="gemini-2.0-flash"):
        self.api_key = api_key
        self.temperature = temperature
        self.model_name = model_name

    def generate(self, prompt):
        """POST *prompt* to the generateContent endpoint and return the text.

        Raises:
            Exception: if the API responds with a non-200 status code.
        """
        url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.model_name}:generateContent?key={self.api_key}"
        )
        payload = {
            "contents": [{"parts": [{"text": prompt}]}],
            # Fix: temperature was stored on the instance but never sent to
            # the API; generationConfig is where Gemini expects it.
            "generationConfig": {"temperature": self.temperature},
        }
        headers = {"Content-Type": "application/json"}
        # Fix: add a timeout so a hung request cannot block the UI forever.
        response = requests.post(url, json=payload, headers=headers, timeout=60)
        if response.status_code != 200:
            raise Exception(f"Gemini API error: {response.status_code} - {response.text}")
        data = response.json()
        candidate = data.get("candidates", [{}])[0]
        # Fix: generateContent responses nest the text under
        # candidates[0].content.parts[0].text -- there is no "output" key,
        # so the original code always returned the fallback string.
        parts = candidate.get("content", {}).get("parts", [])
        if parts and "text" in parts[0]:
            return parts[0]["text"]
        return "No output from Gemini API"

    def __call__(self, prompt, **kwargs):
        return self.generate(prompt)
# ------------------------------
# Setup Logging
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Accumulated log text displayed in the UI "Log Window"; appended to by update_log().
log_messages = ""


def update_log(message):
    """Append *message* to the in-memory log buffer and emit it via logging."""
    global log_messages
    log_messages = f"{log_messages}{message}\n"
    logger.info(message)
# ------------------------------
# PDF Embedding & QA Chain (No OCR)
# ------------------------------
def load_pdf_and_generate_embeddings(pdf_doc, gemini_api_key, relevant_pages):
    """Index a PDF into a Chroma vector store and build the global QA chain.

    Args:
        pdf_doc: uploaded file from gr.File -- with type='filepath' this is a
            plain path string; older gradio versions pass an object with .name.
        gemini_api_key: API key forwarded to the ChatGemini wrapper.
        relevant_pages: optional comma-separated 1-based page numbers; empty or
            all-invalid input falls back to indexing the whole PDF.

    Returns:
        "Ready" on success, or an "Error: ..." string on failure.
    """
    try:
        # Fix: gr.File(type='filepath') passes a plain path string, which has
        # no .name attribute -- accept both the string and the legacy object.
        pdf_path = getattr(pdf_doc, "name", pdf_doc)
        # NOTE(review): OnlinePDFLoader is designed for URLs -- confirm it
        # handles local paths, or switch to a local-file PDF loader.
        loader = OnlinePDFLoader(pdf_path)
        pages = loader.load_and_split()
        update_log(f"Extracted text from {len(pages)} pages in {pdf_path}")

        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        # Collect the requested pages, silently skipping non-numeric or
        # out-of-range entries.
        pages_to_be_loaded = []
        if relevant_pages:
            for token in relevant_pages.split(","):
                token = token.strip()
                if token.isdigit():
                    page_index = int(token) - 1  # user input is 1-based
                    if 0 <= page_index < len(pages):
                        pages_to_be_loaded.append(pages[page_index])
        if not pages_to_be_loaded:
            pages_to_be_loaded = pages.copy()
            update_log("No specific pages selected; using entire PDF.")

        vectordb = Chroma.from_documents(pages_to_be_loaded, embedding=embeddings)

        prompt_template = (
            """Use the following context to answer the question. If you do not know the answer, return N/A.
{context}
Question: {question}
Return the answer in JSON format."""
        )
        PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        chain_type_kwargs = {"prompt": PROMPT}

        # The chain is stored globally so the query/summarize handlers can use it.
        global pdf_qa
        pdf_qa = RetrievalQA.from_chain_type(
            llm=ChatGemini(api_key=gemini_api_key, temperature=0, model_name="gemini-2.0-flash"),
            chain_type="stuff",
            retriever=vectordb.as_retriever(search_kwargs={"k": 5}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=False
        )
        update_log("PDF embeddings generated and QA chain initialized using Gemini.")
        return "Ready"
    except Exception as e:
        # Boundary handler: surface the failure in the Status textbox.
        update_log(f"Error in load_pdf_and_generate_embeddings: {str(e)}")
        return f"Error: {str(e)}"
# ------------------------------
# SQLite Question Set Functions
# ------------------------------
def create_db_connection():
    """Open (creating on first use) the on-disk question-set database.

    check_same_thread=False allows the connection to be used from gradio
    worker threads other than the one that created it.
    """
    db_file = "./questionset.db"
    return sqlite3.connect(db_file, check_same_thread=False)
def create_sqlite_table(connection):
    """Ensure the `questions` table exists on *connection* (idempotent)."""
    update_log("Creating/Verifying SQLite table for questions.")
    cursor = connection.cursor()
    # Probe sqlite_master directly instead of the original SELECT-everything-
    # and-catch-OperationalError dance, which fetched every row just to test
    # whether the table exists.
    table_exists = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='questions'"
    ).fetchone() is not None
    if not table_exists:
        cursor.execute(
            '''
            CREATE TABLE questions (document_type TEXT NOT NULL, questionset_tag TEXT NOT NULL, field TEXT NOT NULL, question TEXT NOT NULL)
            '''
        )
        update_log("Questions table created.")
    connection.commit()
def load_master_questionset_into_sqlite(connection):
    """Seed the built-in "masterlist" question sets for DOC_A and DOC_B.

    Idempotent: the insert only runs when the DOC_A masterlist is absent.
    """
    create_sqlite_table(connection)
    cursor = connection.cursor()
    masterlist_count = cursor.execute(
        "SELECT COUNT(*) FROM questions WHERE document_type=? AND questionset_tag=?",
        ("DOC_A", "masterlist"),
    ).fetchone()[0]
    if masterlist_count == 0:
        update_log("Loading masterlist into DB.")
        # zip + executemany replaces the original index-based per-row loops.
        for doc_type, factory in (
            ("DOC_A", create_field_and_question_list_for_DOC_A),
            ("DOC_B", create_field_and_question_list_for_DOC_B),
        ):
            fields, queries = factory()
            cursor.executemany(
                "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
                [(doc_type, "masterlist", field, query) for field, query in zip(fields, queries)],
            )
        connection.commit()
    total_questions = cursor.execute("SELECT COUNT(*) FROM questions").fetchone()[0]
    update_log(f"Total questions in DB: {total_questions}")
def create_field_and_question_list_for_DOC_A():
    """Return the built-in (fields, questions) master lists for DOC_A documents."""
    field_names = ["Loan Number", "Borrower"]
    question_texts = ["What is the Loan Number?", "Who is the Borrower?"]
    return field_names, question_texts
def create_field_and_question_list_for_DOC_B():
    """Return the built-in (fields, questions) master lists for DOC_B documents."""
    field_names = ["Property Address", "Signed Date"]
    question_texts = ["What is the Property Address?", "What is the Signed Date?"]
    return field_names, question_texts
def retrieve_document_type_and_questionsettag_from_sqlite():
    """Build the "doc_type:tag" dropdown choices from the questions table.

    Seeds the master question sets first, then returns a gradio update that
    repopulates the dropdown (selecting the first choice, or "" when empty).
    """
    connection = create_db_connection()
    load_master_questionset_into_sqlite(connection)
    cursor = connection.cursor()
    rows = cursor.execute(
        "SELECT document_type, questionset_tag FROM questions ORDER BY document_type, UPPER(questionset_tag)"
    ).fetchall()
    connection.close()
    # dict.fromkeys dedupes in O(n) while preserving the query's order
    # (the original list-membership test was O(n^2)).
    choices = list(dict.fromkeys(f"{doc_type}:{tag}" for doc_type, tag in rows))
    for value in choices:
        update_log(f"Found question set: {value}")
    # gr.update works on both gradio 3.x and 4.x; gr.Dropdown.update was
    # removed in gradio 4 -- TODO confirm the Space's pinned gradio version.
    return gr.update(choices=choices, value=choices[0] if choices else "")
def retrieve_fields_and_questions(dropdownoption):
    """Return a DataFrame of (documentType, field, question) for a dropdown value.

    Args:
        dropdownoption: a "document_type:questionset_tag" string as produced by
            retrieve_document_type_and_questionsettag_from_sqlite.
    """
    # partition splits only on the first ':', so a tag that itself contains a
    # colon is preserved (split(":")[1] would truncate it).
    document_type, _, questionset_tag = dropdownoption.partition(":")
    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT document_type, field, question FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, questionset_tag),
        ).fetchall()
    finally:
        # Close even if the query raises, so the connection never leaks.
        connection.close()
    return pd.DataFrame(rows, columns=["documentType", "field", "question"])
def add_questionset(data, document_type, tag_for_questionset):
    """Insert every (field, question) row of *data* under the given type/tag.

    Args:
        data: DataFrame with 'field' and 'question' columns.
        document_type: e.g. "DOC_A" / "DOC_B".
        tag_for_questionset: user-chosen name for this question set.
    """
    connection = create_db_connection()
    create_sqlite_table(connection)
    # One executemany batch instead of a per-row execute loop.
    rows = [
        (document_type, tag_for_questionset, row["field"], row["question"])
        for _, row in data.iterrows()
    ]
    connection.executemany(
        "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
        rows,
    )
    connection.commit()
    connection.close()
def load_csv_and_store_questionset_into_sqlite(csv_file, document_type, tag_for_questionset):
    """Load a fields/questions CSV into the DB and return a status message."""
    # Guard clause: both the document type and a set name are required.
    if not (tag_for_questionset and document_type):
        return "Please select a Document Type and provide a name for the Question Set"
    # Fix: gr.File(type='filepath') passes a plain path string, which has no
    # .name attribute -- accept both the string and the legacy file object.
    csv_path = getattr(csv_file, "name", csv_file)
    data = pd.read_csv(csv_path)
    add_questionset(data, document_type, tag_for_questionset)
    response = f"Uploaded {data.shape[0]} fields and questions for {document_type}:{tag_for_questionset}"
    update_log(response)
    return response
def answer_predefined_questions(document_type_and_questionset):
    """Run every question of a stored set through the QA chain.

    Args:
        document_type_and_questionset: "document_type:questionset_tag" string.

    Returns:
        DataFrame with Field / Question / Response columns; per-question
        failures are recorded as "Error: ..." strings in the Response column.
    """
    # partition is robust to missing or extra ':' characters, where
    # split(":")[1] would raise IndexError or truncate the tag.
    document_type, _, question_set = document_type_and_questionset.partition(":")
    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT field, question FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, question_set),
        ).fetchall()
    finally:
        # Close even if the query raises, so the connection never leaks.
        connection.close()
    fields, questions, responses = [], [], []
    for field, question in rows:
        fields.append(field)
        questions.append(question)
        try:
            responses.append(pdf_qa.run(question))
        except Exception as e:
            # Also covers pdf_qa not yet initialized (NameError).
            err = f"Error: {str(e)}"
            update_log(err)
            responses.append(err)
    return pd.DataFrame({"Field": fields, "Question": questions, "Response": responses})
def summarize_contents():
    """Ask the QA chain for a short summary of the loaded PDF.

    Returns the summary text, or an error string when no PDF has been
    indexed yet or the chain raises.
    """
    question = "Generate a short summary of the contents along with up to 3 example questions."
    if 'pdf_qa' not in globals():
        return "Error: PDF embeddings not generated. Load a PDF first."
    try:
        summary_text = pdf_qa.run(question)
        update_log("Summarization successful.")
        return summary_text
    except Exception as exc:
        message = f"Error in summarization: {str(exc)}"
        update_log(message)
        return message
def answer_query(query):
    """Answer a free-form question against the indexed PDF via the QA chain.

    Returns the chain's answer, or an error string when no PDF has been
    indexed yet or the chain raises.
    """
    if 'pdf_qa' not in globals():
        return "Error: PDF embeddings not generated. Load a PDF first."
    try:
        answer = pdf_qa.run(query)
        update_log(f"Query answered: {query}")
        return answer
    except Exception as exc:
        message = f"Error in answering query: {str(exc)}"
        update_log(message)
        return message
def get_log():
    """Return the accumulated in-memory log text for the UI Log Window."""
    return log_messages
# ------------------------------
# Gradio Interface
# ------------------------------
css = """
#col-container {max-width: 700px; margin: auto;}
"""
title = """
<div style="text-align: center;">
<h1>AskMoli - Chatbot for PDFs</h1>
<p>Upload a PDF and generate embeddings. Then ask questions or use a predefined set.</p>
</div>
"""

with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        # --- Tab 1: upload, summarize, free-form Q&A, predefined sets ---
        with gr.Tab("Chatbot"):
            with gr.Column():
                gemini_api_key = gr.Textbox(label="Your Gemini API Key", type="password")
                pdf_doc = gr.File(label="Load a PDF", file_types=['.pdf'], type='filepath')
                relevant_pages = gr.Textbox(label="Optional: Comma separated page numbers")
                with gr.Row():
                    status = gr.Textbox(label="Status", interactive=False)
                    load_pdf_btn = gr.Button("Upload PDF & Generate Embeddings")
                with gr.Row():
                    summary = gr.Textbox(label="Summary")
                    summarize_pdf_btn = gr.Button("Summarize Contents")
                with gr.Row():
                    input_query = gr.Textbox(label="Your Question")
                    output_answer = gr.Textbox(label="Answer")
                submit_query_btn = gr.Button("Submit Question")
                with gr.Row():
                    questionsets = gr.Dropdown(label="Pre-defined Question Sets", choices=[])
                    load_questionsets_btn = gr.Button("Retrieve Sets")
                fields_and_questions = gr.Dataframe(label="Fields & Questions")
                load_fields_btn = gr.Button("Retrieve Questions")
                with gr.Row():
                    answers_df = gr.Dataframe(label="Pre-defined Answers")
                    answer_predefined_btn = gr.Button("Get Answers")
                log_window = gr.Textbox(label="Log Window", interactive=False, lines=10)

        # --- Tab 2: raw text extraction ---
        with gr.Tab("Text Extractor"):
            with gr.Column():
                image_pdf = gr.File(label="Load PDF for Text Extraction", file_types=['.pdf'], type='filepath')
                with gr.Row():
                    extracted_text = gr.Textbox(label="Extracted Text", lines=10)
                    extract_btn = gr.Button("Extract Text")

                def extract_text(pdf_file):
                    """Extract and return the plain text of an uploaded PDF."""
                    try:
                        # Fix: gr.File(type='filepath') passes a plain path
                        # string with no .name attribute -- accept both the
                        # string and the legacy file object.
                        pdf_path = getattr(pdf_file, "name", pdf_file)
                        loader = OnlinePDFLoader(pdf_path)
                        docs = loader.load_and_split()
                        text = "\n".join([doc.page_content for doc in docs])
                        update_log(f"Extracted text from {len(docs)} pages.")
                        return text
                    except Exception as e:
                        err = f"Error extracting text: {str(e)}"
                        update_log(err)
                        return err

                extract_btn.click(extract_text, inputs=image_pdf, outputs=extracted_text)

        # --- Tab 3: custom question-set upload ---
        with gr.Tab("Upload Question Set"):
            with gr.Column():
                document_type_for_questionset = gr.Dropdown(choices=["DOC_A", "DOC_B"], label="Select Document Type")
                tag_for_questionset = gr.Textbox(label="Name for Question Set (e.g., basic-set)")
                csv_file = gr.File(label="Load CSV (fields,question)", file_types=['.csv'], type='filepath')
                with gr.Row():
                    status_for_csv = gr.Textbox(label="Status", interactive=False)
                    load_csv_btn = gr.Button("Upload CSV into DB")

        refresh_log_btn = gr.Button("Refresh Log")

    # Event wiring.
    refresh_log_btn.click(get_log, outputs=log_window)
    load_pdf_btn.click(load_pdf_and_generate_embeddings, inputs=[pdf_doc, gemini_api_key, relevant_pages], outputs=status)
    summarize_pdf_btn.click(summarize_contents, outputs=summary)
    submit_query_btn.click(answer_query, inputs=input_query, outputs=output_answer)
    load_questionsets_btn.click(retrieve_document_type_and_questionsettag_from_sqlite, outputs=questionsets)
    load_fields_btn.click(retrieve_fields_and_questions, inputs=questionsets, outputs=fields_and_questions)
    answer_predefined_btn.click(answer_predefined_questions, inputs=questionsets, outputs=answers_df)
    load_csv_btn.click(load_csv_and_store_questionset_into_sqlite, inputs=[csv_file, document_type_for_questionset, tag_for_questionset], outputs=status_for_csv)

demo.launch(debug=True)