Spaces:
Sleeping
Sleeping
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| from chromadb.config import Settings | |
| from transformers import pipeline | |
| import streamlit as st | |
| import fitz | |
| from PIL import Image | |
# NOTE(review): `config` is not referenced anywhere in this file;
# setup_chromadb() constructs its own PersistentClient from the same
# "./chromadb_data" path, so these settings do not influence the app as
# written. Confirm whether this object is still needed.
config = Settings(
    chroma_db_impl="sqlite",
    persist_directory="./chromadb_data",
)
# Embedding model for every collection this app creates. Keeping it in one
# place guarantees that a collection recreated by clear_collection() embeds
# documents with the same model as the one created by setup_chromadb().
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


def _make_embedding_function():
    """Build the SentenceTransformer embedding function for EMBEDDING_MODEL_NAME."""
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME
    )


def setup_chromadb(path="./chromadb_data", collection_name="pdf_data"):
    """Open the persistent Chroma store and get or create the PDF collection.

    Args:
        path: Directory handed to ``chromadb.PersistentClient``.
        collection_name: Name of the collection to get or create.

    Returns:
        A ``(client, collection)`` tuple.
    """
    client = chromadb.PersistentClient(path=path)
    collection = client.get_or_create_collection(
        name=collection_name,
        embedding_function=_make_embedding_function(),
    )
    return client, collection


def clear_collection(client, collection_name):
    """Drop *collection_name* and return a fresh, empty collection of that name.

    The new collection uses the same embedding function as setup_chromadb().

    NOTE(review): Chroma's ``delete_collection`` is expected to raise when the
    collection does not exist; main() always calls setup_chromadb() first,
    which creates it. Confirm before calling this from elsewhere.
    """
    client.delete_collection(name=collection_name)
    return client.get_or_create_collection(
        name=collection_name,
        embedding_function=_make_embedding_function(),
    )
def extract_text_from_pdf(uploaded_file):
    """Return the text of every page of an uploaded PDF, concatenated in order.

    Args:
        uploaded_file: Binary file-like object (in this app, the value
            returned by ``st.file_uploader``) whose ``read()`` yields the PDF
            bytes.

    Returns:
        The ``get_text()`` output of each page, joined with no separator.
    """
    with fitz.open(stream=uploaded_file.read(), filetype="pdf") as doc:
        # Join once instead of growing a string with ``+=`` per page, which
        # can be quadratic in the total text length.
        return "".join(page.get_text() for page in doc)
def add_pdf_text_to_db(collection, pdf_text, batch_size=5000):
    """Store each non-blank line of *pdf_text* in *collection* as a document.

    Line ``i`` (0-based, blank lines included in the count) is stored with id
    ``"pdf_text_{i}"``, the unmodified line as its document, and metadata
    ``{"line_number": i, "text": line}``. Lines are sent in a few batched
    ``collection.add`` calls rather than one call per line, so embeddings
    are computed in bulk.

    Args:
        collection: Chroma collection (anything with a compatible ``add``).
        pdf_text: Text to index; it is split on newline characters.
        batch_size: Maximum number of lines per ``add`` call, which keeps
            each request bounded for very long documents.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    ids, documents, metadatas = [], [], []
    for idx, line in enumerate(pdf_text.split("\n")):
        if line.strip():
            ids.append(f"pdf_text_{idx}")
            documents.append(line)
            metadatas.append({"line_number": idx, "text": line})

    # When there are no non-blank lines the range is empty and no add() is
    # issued, just like the original per-line loop.
    for start in range(0, len(ids), batch_size):
        end = start + batch_size
        collection.add(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
        )
def query_pdf_data(collection, query, retriever_model, n_results=3):
    """Answer *query* using the documents in *collection* most similar to it.

    The ``n_results`` nearest documents are joined with spaces into a context
    string, and a "Context: ... / Question: ..." prompt is passed to
    *retriever_model*.

    Args:
        collection: Chroma collection to search.
        query: The user's question.
        retriever_model: Callable that takes the prompt string; in this app,
            a transformers ``text2text-generation`` pipeline.
        n_results: Number of documents to retrieve as context.

    Returns:
        ``(answer, metadatas)``: the raw output of *retriever_model* and the
        ``"metadatas"`` entry of the query result (one list per query text;
        exactly one query text is sent here).
    """
    results = collection.query(query_texts=[query], n_results=n_results)
    # Results are grouped per query text; index 0 is our single query.
    context = " ".join(results["documents"][0])
    answer = retriever_model(f"Context: {context}\nQuestion: {query}")
    return answer, results["metadatas"]
@st.cache_resource
def _load_generator():
    """Load the Flan-T5-Small text2text-generation pipeline once and reuse it.

    Streamlit re-executes the whole script on every widget interaction;
    without ``st.cache_resource`` the model would be reloaded on each rerun.
    """
    return pipeline("text2text-generation", model="google/flan-t5-small")


def main():
    """Render the PDF chatbot page: upload and index a PDF, then answer queries.

    NOTE(review): every session uses the same persistent "pdf_data"
    collection (fixed path and name), so one user's upload replaces the data
    another user is querying.
    """
    image = Image.open('LOGO.PNG')
    st.image(image, width=250)
    st.title("PDF Chatbot with RAG")
    st.markdown("Google Flan-T5-Small + ChromaDB")
    st.header('', divider='rainbow')
    st.write("Upload a PDF, and ask questions about its content!")

    client, collection = setup_chromadb()
    retriever_model = _load_generator()

    # File upload
    uploaded_file = st.file_uploader("Upload your PDF file", type="pdf")
    if uploaded_file:
        # Each rerun (e.g. typing a query) would otherwise wipe the database
        # and re-embed the whole PDF; only re-index when the file changes.
        file_key = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get("indexed_pdf") != file_key:
            st.session_state.pop("indexed_pdf", None)
            st.session_state.pop("pdf_text", None)
            try:
                collection = clear_collection(client, "pdf_data")
                st.info("Existing data cleared from the database.")
                pdf_text = extract_text_from_pdf(uploaded_file)
                st.success("Text extracted successfully!")
                st.session_state["pdf_text"] = pdf_text
                add_pdf_text_to_db(collection, pdf_text)
                # Only mark the file as indexed once it is fully stored, so a
                # failure is retried on the next rerun.
                st.session_state["indexed_pdf"] = file_key
                st.success("PDF text has been added to the database. You can now query it!")
            except Exception as e:
                st.error(f"Error extracting text: {e}")
        if "pdf_text" in st.session_state:
            st.text_area("Extracted Text:", st.session_state["pdf_text"], height=300)
    else:
        # File removed: forget it so that uploading it again re-indexes it.
        st.session_state.pop("indexed_pdf", None)
        st.session_state.pop("pdf_text", None)

    query = st.text_input("Enter your query about the PDF:")
    if query:
        try:
            answer, metadata = query_pdf_data(collection, query, retriever_model)
            st.subheader("Answer:")
            st.write(answer[0]['generated_text'])
            st.subheader("Retrieved Context:")
            # metadata[0] lists the metadata dicts of the hits for our single
            # query; each holds the matched line under "text".
            for meta in metadata[0]:
                st.write(meta)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()