Spaces:
Sleeping
Sleeping
| from typing import Any, List, Tuple | |
| import gradio as gr | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.document_loaders import PyMuPDFLoader | |
| import fitz | |
| from PIL import Image | |
| import os | |
| import re | |
| import openai | |
# MyApp holds the RAG pipeline state shared by all Gradio callbacks:
# the API key, the loaded PDF documents, and the retrieval chain.
class MyApp:
    def __init__(self) -> None:
        """Initialize empty state; key, documents and chain are set later."""
        self.OPENAI_API_KEY = None   # str once set_api_key() is called
        self.chain = None            # ConversationalRetrievalChain, built in build_chain()
        self.chat_history: list = [] # (question, answer) tuples fed back to the chain
        self.documents = None        # LangChain Documents loaded from the PDF
        self.file_name = None        # basename of the processed PDF, reused as collection name

    def set_api_key(self, api_key: str) -> None:
        """Store the OpenAI API key and propagate it to the openai module."""
        self.OPENAI_API_KEY = api_key
        openai.api_key = api_key

    def process_file(self, file) -> Image.Image:
        """Load the PDF into LangChain documents and return a 150-dpi
        preview image of its first page.

        `file` is any object with a `.name` attribute holding a PDF path
        (a Gradio upload or an open file handle).
        """
        loader = PyMuPDFLoader(file.name)
        self.documents = loader.load()
        self.file_name = os.path.basename(file.name)
        # Close the fitz document explicitly so the file handle is not leaked.
        doc = fitz.open(file.name)
        try:
            pix = doc[0].get_pixmap(dpi=150)
            return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        finally:
            doc.close()

    def build_chain(self, file) -> str:
        """Embed the loaded documents into Chroma and build the
        conversational retrieval chain.

        Requires set_api_key() and process_file() to have been called first.
        Returns a human-readable status string for the UI.
        """
        embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY)
        pdfsearch = Chroma.from_documents(
            self.documents,
            embeddings,
            # NOTE(review): Chroma restricts collection names (3-63 chars,
            # limited character set); an arbitrary PDF file name may be
            # rejected — consider sanitizing before release.
            collection_name=self.file_name,
        )
        self.chain = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(temperature=0.0, openai_api_key=self.OPENAI_API_KEY),
            # k=1: retrieve only the single most relevant chunk per query.
            retriever=pdfsearch.as_retriever(search_kwargs={"k": 1}),
            return_source_documents=True,
        )
        return "Vector database built successfully!"
# Push the user's message onto the chat history with an empty bot slot.
def add_text(history: List[Tuple[str, str]], text: str) -> List[Tuple[str, str]]:
    """Validate the submitted text and append it to the chat history.

    The bot half of the new tuple is left empty; it is filled in later by
    the response callback. Raises gr.Error on empty input so Gradio shows
    a popup instead of silently accepting a blank message.
    """
    if text:
        history.append((text, ""))
        return history
    raise gr.Error("Enter text")
# Run the retrieval chain for `query`, update the shared chat state, and
# return the updated history plus a dump of the retrieved source passages.
def get_response(history, query):
    """Answer `query` via the conversational retrieval chain.

    Returns (history, source_texts): history has its last entry's bot slot
    filled with the answer; source_texts lists each retrieved passage with
    its 1-based page number. Any failure falls back to a canned message.
    """
    if app.chain is None:
        raise gr.Error("The chain has not been built yet. Please ensure the vector database is built before querying.")
    fallback = "I have no information about it. Feed me knowledge, please!"
    try:
        result = app.chain.invoke(
            {"question": query, "chat_history": app.chat_history}
        )
        answer = result["answer"]
        app.chat_history.append((query, answer))
        # One entry per retrieved document: "Page N: <content>".
        sources = "\n\n".join(
            f"Page {doc.metadata['page'] + 1}: {doc.page_content}"
            for doc in result["source_documents"]
        )
        history[-1] = (history[-1][0], answer)
        return history, sources
    except Exception as e:
        # Best-effort UI callback: report the failure instead of crashing.
        app.chat_history.append((query, fallback))
        return history, f"{fallback} Error: {str(e)}"
# The current-RAG-tab handler was a line-for-line duplicate of get_response;
# delegate to it instead of maintaining two copies.
def get_response_current(history, query):
    """Answer `query` for the current-document tab.

    Thin wrapper over get_response(); kept as a separate name so the Gradio
    wiring (and any future tab-specific behavior) stays independent.
    Returns the same (history, source_texts) pair.
    """
    return get_response(history, query)
# Render the first page of a PDF as a PIL image for display in the UI.
def render_file(file) -> Image.Image:
    """Return a 150-dpi RGB image of page 1 of `file`.

    `file` needs only a `.name` attribute holding a PDF path. The fitz
    document is closed in a finally block so the handle is not leaked.
    """
    doc = fitz.open(file.name)
    try:
        pix = doc[0].get_pixmap(dpi=150)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    finally:
        doc.close()
# Reset the conversation and show page 1 of the newly uploaded PDF.
def purge_chat_and_render_first(file) -> Image.Image:
    """Clear the shared chat history, then render the PDF's first page.

    Delegates the rendering to render_file() instead of duplicating the
    fitz preview code a third time.
    """
    app.chat_history = []
    return render_file(file)
# Wipe both the model-side memory and the visible chatbot component.
def refresh_chat():
    """Reset the shared chat history; the empty list returned blanks the
    Chatbot component it is wired to."""
    app.chat_history = []
    return []
# Single module-level application instance shared by every Gradio callback.
app = MyApp()
# Gradio callback: store the API key, then eagerly index the bundled PDF
# so the chat tab is ready before the first question.
def set_api_key(api_key):
    """Set the OpenAI key on the shared app and build the vector database
    from the bundled PDF ("THEDIA1.pdf", expected in the working directory).

    Returns a status string for the status textbox. Raises gr.Error for an
    empty key so the UI shows a popup instead of a stack trace.
    """
    if not api_key:
        raise gr.Error("Please enter an OpenAI API key.")
    app.set_api_key(api_key)
    # Pre-process the saved PDF file after setting the API key.
    saved_file_path = "THEDIA1.pdf"
    with open(saved_file_path, 'rb') as saved_file:
        app.process_file(saved_file)
        app.build_chain(saved_file)
    # Mask the key in the status message; slicing a key shorter than
    # 8 chars would overlap/echo it, so fall back to full masking.
    masked = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) >= 8 else "****"
    return f"API Key set to {masked} and vector database built successfully!"
# Gradio interface
with gr.Blocks() as demo:
    # NOTE(review): `title` is assigned but never rendered (no Markdown
    # component displays it and gr.Blocks is not passed title=) — confirm intent.
    title = "🧘♀️ Dialectical Behaviour Therapy"
    # API-key row: the key must be set (which also builds the vector DB)
    # before the chat tab can answer anything.
    api_key_input = gr.Textbox(label="OpenAI API Key", type="password", placeholder="Enter your OpenAI API Key")
    api_key_btn = gr.Button("Set API Key")
    api_key_status = gr.Textbox(value="API Key status", interactive=False)
    api_key_btn.click(
        fn=set_api_key,
        inputs=[api_key_input],
        outputs=[api_key_status]
    )
    with gr.Tab("Take a Dialectical Behaviour Therapy with Me"):
        with gr.Column():
            chatbot_current = gr.Chatbot(elem_id="chatbot_current")
            txt_current = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press submit",
                scale=2
            )
            submit_btn_current = gr.Button("Submit", scale=1)
            refresh_btn_current = gr.Button("Refresh Chat", scale=1)
            source_texts_output_current = gr.Textbox(label="Source Texts", interactive=False)
        # Submit chains two steps: add_text appends the user message, and
        # only on its success does get_response_current run the chain and
        # fill in the answer plus the retrieved source passages.
        submit_btn_current.click(
            fn=add_text,
            inputs=[chatbot_current, txt_current],
            outputs=[chatbot_current],
            queue=False,
        ).success(
            fn=get_response_current, inputs=[chatbot_current, txt_current], outputs=[chatbot_current, source_texts_output_current]
        )
        refresh_btn_current.click(
            fn=refresh_chat,
            inputs=[],
            outputs=[chatbot_current],
        )
# Enable request queuing, then start the server.
demo.queue()
demo.launch()